

/**
 * \file dnn/test/arm_common/conv_bias_multi_thread.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "megdnn/dtype.h"
#include "test/arm_common/fixture.h"
#include "test/common/benchmarker.h"
#include "test/common/conv_bias.h"
#include "test/arm_common/cpuinfo_help.h"
using namespace megdnn;
using namespace test;
using namespace conv_bias;
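//! build test cases for the int8/quint8 direct conv algorithms: every combination
//! of batch {1, 2}, ic/oc {1, 3, 7}, spatial size {4, 6, 8, 14, 16, 18} and the
//! given kernel sizes at the given stride, optionally with bias and nonlinearities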
std::vector<conv_bias::TestArg> get_int8_quint8_conv_bias_args(
        std::vector<size_t> kernel, size_t stride, bool no_pad, bool no_bias,
        bool no_nonlinemode) {
    using namespace conv_bias;
    using Param = param::ConvBias;
    using NLMode = param::ConvBias::NonlineMode;
    std::vector<TestArg> args;
    auto pack = [&](size_t n, size_t oc, size_t ic, size_t w, size_t h, size_t kernel,
                    size_t stride, NLMode nlmode) {
        Param param;
        param.stride_h = stride;
        param.stride_w = stride;
        if (!no_pad) {
            param.pad_h = kernel / 2;
            param.pad_w = kernel / 2;
        } else {
            param.pad_h = 0;
            param.pad_w = 0;
        }
        param.nonlineMode = nlmode;
        args.emplace_back(
                param, TensorShape{n, ic, h, w}, TensorShape{oc, ic, kernel, kernel},
                TensorShape{});
        if (!no_bias) {
            args.emplace_back(
                    param, TensorShape{n, ic, h, w},
                    TensorShape{oc, ic, kernel, kernel}, TensorShape{1, oc, 1, 1});
        }
    };
    std::vector<NLMode> nonlinemode = {NLMode::IDENTITY};
    if (!no_nonlinemode) {
        nonlinemode.emplace_back(NLMode::RELU);
        nonlinemode.emplace_back(NLMode::H_SWISH);
    }
    for (size_t n : {1, 2}) {
        for (auto nlmode : nonlinemode) {
            for (size_t ic : {1, 3, 7}) {
                for (size_t oc : {1, 3, 7}) {
                    for (size_t size : {4, 6, 8, 14, 16, 18}) {
                        for (size_t kern : kernel) {
                            pack(n, oc, ic, size, size, kern, stride, nlmode);
                        }
                    }
                }
            }
        }
    }
    return args;
}
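//! build test cases for the NCHW44 channel-wise (group) conv algorithms; depending
//! on the flags, cases are added without bias, with a per-channel bias, or with a
//! full-shape bias tensor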
std::vector<conv_bias::TestArg> get_nchw44_channel_wise_args(
        std::vector<size_t> kernel, size_t stride, bool no_bias, bool no_nonlinemode,
        bool no_full_bias) {
    using namespace conv_bias;
    using Param = param::ConvBias;
    using NLMode = param::ConvBias::NonlineMode;
    std::vector<TestArg> args;
    auto pack = [&](size_t n, size_t group, size_t w, size_t h, size_t kernel,
                    size_t stride, NLMode nlmode, bool pad) {
        Param param;
        param.stride_h = stride;
        param.stride_w = stride;
        if (pad) {
            param.pad_h = kernel / 2;
            param.pad_w = kernel / 2;
        } else {
            param.pad_h = 0;
            param.pad_w = 0;
        }
        param.nonlineMode = nlmode;
        param.format = param::ConvBias::Format::NCHW44;
        param.sparse = param::ConvBias::Sparse::GROUP;
        args.emplace_back(
                param, TensorShape{n, group, h, w, 4},
                TensorShape{group, 1, 1, kernel, kernel, 4}, TensorShape{});
        if (!no_bias) {
            args.emplace_back(
                    param, TensorShape{n, group, h, w, 4},
                    TensorShape{group, 1, 1, kernel, kernel, 4},
                    TensorShape{1, group, 1, 1, 4});
        }
        if (!no_full_bias) {
            args.emplace_back(
                    param, TensorShape{n, group, h, w, 4},
                    TensorShape{group, 1, 1, kernel, kernel, 4},
                    TensorShape{
                            n, group, (h + 2 * param.pad_w - kernel) / stride + 1,
                            (w + 2 * param.pad_w - kernel) / stride + 1, 4});
        }
    };
    std::vector<NLMode> nonlinemode = {NLMode::IDENTITY};
    if (!no_nonlinemode) {
        nonlinemode.emplace_back(NLMode::RELU);
        nonlinemode.emplace_back(NLMode::H_SWISH);
    }
    for (size_t n : {1, 2}) {
        for (auto nlmode : nonlinemode) {
            for (bool pad : {true}) {
                for (size_t group : {1, 2, 4, 7, 16}) {
                    for (size_t size : {4, 6, 7, 9, 20}) {
                        for (size_t kern : kernel) {
                            pack(n, group, size, size, kern, stride, nlmode, pad);
                        }
                    }
                }
            }
            for (bool pad : {false}) {
                for (size_t group : {1, 2, 7, 16}) {
                    for (size_t size : {7, 9, 20}) {
                        for (size_t kern : kernel) {
                            pack(n, group, size, size, kern, stride, nlmode, pad);
                        }
                    }
                }
            }
        }
    }
    return args;
}
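//! same as get_nchw44_channel_wise_args, but for the NCHW88 (8-channel packed) format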
std::vector<conv_bias::TestArg> get_nchw88_channel_wise_args(
        std::vector<size_t> kernel, size_t stride, bool no_bias, bool no_nonlinemode,
        bool no_full_bias) {
    using namespace conv_bias;
    using Param = param::ConvBias;
    using NLMode = param::ConvBias::NonlineMode;
    std::vector<TestArg> args;
    auto pack = [&](size_t n, size_t group, size_t w, size_t h, size_t kernel,
                    size_t stride, NLMode nlmode, bool pad) {
        Param param;
        param.stride_h = stride;
        param.stride_w = stride;
        if (pad) {
            param.pad_h = kernel / 2;
            param.pad_w = kernel / 2;
        } else {
            param.pad_h = 0;
            param.pad_w = 0;
        }
        param.nonlineMode = nlmode;
        param.format = param::ConvBias::Format::NCHW88;
        param.sparse = param::ConvBias::Sparse::GROUP;
        args.emplace_back(
                param, TensorShape{n, group, h, w, 8},
                TensorShape{group, 1, 1, kernel, kernel, 8}, TensorShape{});
        if (!no_bias) {
            args.emplace_back(
                    param, TensorShape{n, group, h, w, 8},
                    TensorShape{group, 1, 1, kernel, kernel, 8},
                    TensorShape{1, group, 1, 1, 8});
        }
        if (!no_full_bias) {
            args.emplace_back(
                    param, TensorShape{n, group, h, w, 8},
                    TensorShape{group, 1, 1, kernel, kernel, 8},
                    TensorShape{
                            n, group, (h + 2 * param.pad_w - kernel) / stride + 1,
                            (w + 2 * param.pad_w - kernel) / stride + 1, 8});
        }
    };
    std::vector<NLMode> nonlinemode = {NLMode::IDENTITY};
    if (!no_nonlinemode) {
        nonlinemode.emplace_back(NLMode::RELU);
        nonlinemode.emplace_back(NLMode::H_SWISH);
    }
    for (size_t n : {1, 2}) {
        for (auto nlmode : nonlinemode) {
            for (bool pad : {true}) {
                for (size_t group : {1, 2, 4, 7, 8, 16}) {
                    for (size_t size : {4, 6, 7, 9, 20}) {
                        for (size_t kern : kernel) {
                            pack(n, group, size, size, kern, stride, nlmode, pad);
                        }
                    }
                }
            }
            for (bool pad : {false}) {
                for (size_t group : {1, 2, 7, 16}) {
                    for (size_t size : {7, 9, 20}) {
                        for (size_t kern : kernel) {
                            pack(n, group, size, size, kern, stride, nlmode, pad);
                        }
                    }
                }
            }
        }
    }
    return args;
}
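//! the checker_* helpers below run a list of TestArgs against one named algorithm
//! (enforced via ConvBiasAlgoChecker) for a fixed src/filter/bias/dst dtype setup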
void checker_conv_bias_qint8x8x8(
        std::vector<conv_bias::TestArg> args, Handle* handle, const char* algo_name) {
    Checker<ConvBias> checker(handle);
    checker.set_before_exec_callback(
            conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name));
#if MEGDNN_ARMV7
    checker.set_epsilon(1);
#endif
    UniformIntRNG rng{-50, 50};
    checker.set_dtype(0, dtype::QuantizedS8(0.41113496f))
            .set_dtype(1, dtype::QuantizedS8(0.01887994f))
            .set_dtype(2, dtype::QuantizedS32(0.41113496f * 0.01887994f))
            .set_dtype(4, dtype::QuantizedS8(0.49550694f))
            .set_rng(0, &rng)
            .set_rng(1, &rng)
            .set_rng(2, &rng);
    for (auto&& arg : args) {
        checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}});
    }
}
void checker_conv_bias_qint8x8x32(
        std::vector<conv_bias::TestArg> args, Handle* handle, const char* algo_name) {
    Checker<ConvBias> checker(handle);
    UniformIntRNG rng{-50, 50};
    checker.set_dtype(0, dtype::QuantizedS8(2.5f))
            .set_dtype(1, dtype::QuantizedS8(2.5f))
            .set_dtype(2, dtype::QuantizedS32(6.25f))
            .set_dtype(4, {});
    checker.set_before_exec_callback(
            conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name));
    for (auto&& arg : args) {
        checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}});
    }
}
void checker_conv_bias_quint8x8x8(
        std::vector<conv_bias::TestArg> args, Handle* handle, const char* algo_name) {
    Checker<ConvBias> checker(handle);
    checker.set_before_exec_callback(
            conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name));
    UniformIntRNG rng(0, 255);
    checker.set_dtype(0, dtype::Quantized8Asymm(0.2f, 100))
            .set_dtype(1, dtype::Quantized8Asymm(0.2f, 120))
            .set_dtype(2, dtype::QuantizedS32(0.04f))
            .set_dtype(4, dtype::Quantized8Asymm(1.4f, 110))
            .set_rng(0, &rng)
            .set_rng(1, &rng)
            .set_rng(2, &rng);
    for (auto&& arg : args) {
        checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}});
    }
}
void checker_conv_bias_quint8x8x32(
        std::vector<conv_bias::TestArg> args, Handle* handle, const char* algo_name) {
    Checker<ConvBias> checker(handle);
    checker.set_before_exec_callback(
            conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name));
    NormalRNG rng(128.f);
    checker.set_rng(0, &rng).set_rng(1, &rng);
    checker.set_dtype(0, dtype::Quantized8Asymm(1.2f, (uint8_t)127))
            .set_dtype(1, dtype::Quantized8Asymm(1.3f, (uint8_t)129))
            .set_dtype(2, dtype::QuantizedS32(1.2 * 1.3))
            .set_dtype(4, {});
    for (auto&& arg : args) {
        checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}});
    }
}
void checker_conv_bias_int8x8x32_multi(
        std::vector<conv_bias::TestArg> args, Handle* handle, const char* algo_name) {
    Checker<ConvBias> checker(handle);
    checker.set_before_exec_callback(
            conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name));
    checker.set_dtype(0, dtype::Int8());
    checker.set_dtype(1, dtype::Int8());
    checker.set_dtype(2, dtype::Int32());
    checker.set_dtype(4, dtype::Int32());
    for (auto&& arg : args) {
        checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}});
    }
}
/**********************************F32 direct************************/
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32) {
    check_conv_bias(
            get_conv_bias_args({1, 2, 3, 4, 5, 6, 7}, 1, false, false, false), handle(),
            "F32DIRECT");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S1_K7) {
    //! k=7 s=1
    check_conv_bias(
            get_nchw44_conv_bias_args({7}, ONLY_IDENTITY_NLMODE, BR_AND_NO_BIASMODE, 1),
            handle(), "F32_CONV_NCHW44_DIRECT");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S1_K2K3) {
    check_conv_bias(
            get_nchw44_conv_bias_args({2, 3}, FULL_NLMODE, ONLY_BR_BIASMODE, 1),
            handle(), "F32_CONV_NCHW44_DIRECT");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S1_K5) {
    check_conv_bias(
            get_nchw44_conv_bias_args({5}, FULL_NLMODE, ONLY_BR_BIASMODE, 1), handle(),
            "F32_CONV_NCHW44_DIRECT");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S2) {
    check_conv_bias(
            get_nchw44_conv_bias_args({2, 3, 5, 7}, FULL_NLMODE, ONLY_BR_BIASMODE, 2),
            handle(), "F32_CONV_NCHW44_DIRECT");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR1) {
    check_conv_bias(
            get_conv_bias_args({2, 3, 5, 7}, 1, false, false, false), handle(),
            "F32STRD1");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR2) {
    check_conv_bias(
            get_conv_bias_args({2, 3, 5, 7}, 2, false, false, false), handle(),
            "F32STRD2");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_NCHW_NCHW44_F32_S2) {
    check_conv_bias(
            get_nchw44_conv_bias_args(
                    {2, 3, 5, 7}, ONLY_IDENTITY_NLMODE, ONLY_BR_BIASMODE, 2, false,
                    true),
            handle(), "F32_CONV_NCHW_NCHW44");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_NCHW_NCHW44_F32_S1) {
    check_conv_bias(
            get_nchw44_conv_bias_args(
                    {2, 3, 5, 7}, ONLY_IDENTITY_NLMODE, ONLY_BR_BIASMODE, 1, false,
                    true),
            handle(), "F32_CONV_NCHW_NCHW44");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CHANNEL_WISE_STRIDE1_FP32_NCHW44_1) {
    check_conv_bias(
            get_nchw44_channel_wise_args({2, 3}, 1, false, false, false), handle(),
            "F32_CHANNEL_WISE_NCHW44");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CHANNEL_WISE_STRIDE1_FP32_NCHW44_2) {
    check_conv_bias(
            get_nchw44_channel_wise_args({5}, 1, false, false, false), handle(),
            "F32_CHANNEL_WISE_NCHW44");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CHANNEL_WISE_STRIDE2_FP32_NCHW44) {
    check_conv_bias(
            get_nchw44_channel_wise_args({2, 3, 5}, 2, false, false, false), handle(),
            "F32_CHANNEL_WISE_NCHW44");
}
/**********************************F16 direct************************/
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16) {
    NormalRNG rng(1);
    checker_conv_bias_f16(
            get_conv_bias_args({1, 2, 3, 4, 5, 6, 7}, 1, false, false, false), handle(),
            rng, "F16DIRECT", 0.03);
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16_STR1) {
    NormalRNG rng(1);
    checker_conv_bias_f16(
            get_conv_bias_args({2, 3, 5}, 1, false, false, false), handle(), rng,
            "F16STRD1", 0.03);
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CHANNEL_WISE_STRIDE1_FP16_NCHW88_1) {
    NormalRNG rng(1);
    checker_conv_bias_f16(
            get_nchw88_channel_wise_args({2, 3}, 1, false, false, false), handle(), rng,
            "F16_CHANNEL_WISE_NCHW88", 0.03);
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CHANNEL_WISE_STRIDE1_FP16_NCHW88_2) {
    NormalRNG rng(1);
    checker_conv_bias_f16(
            get_nchw88_channel_wise_args({5}, 1, false, false, false), handle(), rng,
            "F16_CHANNEL_WISE_NCHW88", 0.03);
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CHANNEL_WISE_STRIDE2_FP16_NCHW88) {
    NormalRNG rng(1);
    checker_conv_bias_f16(
            get_nchw88_channel_wise_args({2, 3, 5}, 2, false, false, false), handle(),
            rng, "F16_CHANNEL_WISE_NCHW88", 0.03);
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16_NCHW88_S1) {
    NormalRNG rng(1);
    checker_conv_bias_f16(
            get_nchw88_conv_bias_args({1, 2, 3, 5, 7}, FULL_NLMODE, ALL_BIASMODE, 1),
            handle(), rng, "F16_CONV_NCHW88_DIRECT", 0.03);
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16_NCHW88_S2) {
    NormalRNG rng(1);
    checker_conv_bias_f16(
            get_nchw88_conv_bias_args({1, 2, 3, 5, 7}, FULL_NLMODE, ALL_BIASMODE, 2),
            handle(), rng, "F16_CONV_NCHW88_DIRECT", 0.03);
}
#endif
/**********************************algo 8816 direct************************/
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT16_DIRECT) {
    checker_conv_bias_int8x8x16(
            get_conv_bias_args({2, 3, 5}, 1, false, true, true), handle(),
            "I8816DIRECT");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT16_STRIDE2) {
    checker_conv_bias_int8x8x16(
            get_conv_bias_args({2, 3, 5}, 2, false, true, true), handle(),
            "I8816STRD2");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT16_NCHW_NCHW44_S2) {
    checker_conv_bias_int8x8x16(
            get_nchw44_conv_bias_args(
                    {2, 3, 5, 7}, ONLY_IDENTITY_NLMODE, ONLY_NO_BIASMODE, 2, false,
                    true),
            handle(), "I8816_CONV_NCHW_NCHW44");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT16_NCHW_NCHW44_S1) {
    checker_conv_bias_int8x8x16(
            get_nchw44_conv_bias_args(
                    {2, 3, 5, 7}, ONLY_IDENTITY_NLMODE, ONLY_NO_BIASMODE, 1, false,
                    true),
            handle(), "I8816_CONV_NCHW_NCHW44");
}
/**********************************algo 8-8-32 direct************************/
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT32_STRIDE1) {
    checker_conv_bias_int8x8x32_multi(
            get_conv_bias_args({2, 3, 5, 7}, 1, false, true, true), handle(),
            "S8STRD1");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT32_STRIDE2) {
    checker_conv_bias_int8x8x32_multi(
            get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(),
            "S8STRD2");
}
TEST_F(ARM_COMMON_MULTI_THREADS,
       CONV_BIAS_INT8_INT8_INT32_CHANNEL_WISE_DIRECT1_NCHW44) {
    checker_conv_bias_int8x8x32_multi(
            get_nchw44_channel_wise_args({2, 3, 5}, 1, false, true, true), handle(),
            "S8_CHAN_WISE_STRD1_NCHW44");
}
TEST_F(ARM_COMMON_MULTI_THREADS,
       CONV_BIAS_INT8_INT8_INT32_CHANNEL_WISE_DIRECT2_NCHW44) {
    checker_conv_bias_int8x8x32_multi(
            get_nchw44_channel_wise_args({2, 3, 5}, 2, false, true, true), handle(),
            "S8_CHAN_WISE_STRD2_NCHW44");
}
TEST_F(ARM_COMMON, CONV_BIAS_INT8_INT8_INT16_CHANNEL_WISE_DIRECT1_NCHW44) {
    Checker<ConvBias> checker(handle());
    checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>(
            "S8x8x16_CHAN_WISE_STRD1_STRD2_NCHW44"));
    checker.set_dtype(0, dtype::Int8());
    checker.set_dtype(1, dtype::Int8());
    checker.set_dtype(2, dtype::Int16());
    checker.set_dtype(4, dtype::Int16());
    auto args = get_nchw44_channel_wise_args({2, 3, 5}, 1, false, true, true);
    for (auto&& arg : args) {
        checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}});
    }
}
TEST_F(ARM_COMMON_MULTI_THREADS,
       CONV_BIAS_INT8_INT8_INT16_CHANNEL_WISE_DIRECT2_NCHW44) {
    Checker<ConvBias> checker(handle());
    checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>(
            "S8x8x16_CHAN_WISE_STRD1_STRD2_NCHW44"));
    checker.set_dtype(0, dtype::Int8());
    checker.set_dtype(1, dtype::Int8());
    checker.set_dtype(2, dtype::Int16());
    checker.set_dtype(4, dtype::Int16());
    auto args = get_nchw44_channel_wise_args({2, 3, 5}, 2, false, true, true);
    for (auto&& arg : args) {
        checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}});
    }
}
/********************************qint8 direct******************************/
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1) {
    checker_conv_bias_qint8x8x8(
            get_int8_quint8_conv_bias_args({2, 3, 5, 7}, 1, false, false, false),
            handle(), "S8STRD1");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2) {
    checker_conv_bias_qint8x8x8(
            get_int8_quint8_conv_bias_args({2, 3, 5, 7}, 2, false, false, false),
            handle(), "S8STRD2");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_NCHW44) {
    checker_conv_bias_qint8x8x8(
            get_nchw44_conv_bias_args({2, 3, 5, 7}, QUAN_NLMODE, ONLY_BR_BIASMODE, 1),
            handle(), "S8_NCHW44_DIRECT");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_NCHW44_8816) {
    checker_conv_bias_int8x8x16(
            get_nchw44_conv_bias_args(
                    {2, 3, 5, 7}, ONLY_IDENTITY_NLMODE, ONLY_BR_BIASMODE, 1),
            handle(), "S8x8x16_NCHW44_DIRECT");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_NCHW44_8816) {
    checker_conv_bias_int8x8x16(
            get_nchw44_conv_bias_args(
                    {2, 3, 5, 7}, ONLY_IDENTITY_NLMODE, ONLY_BR_BIASMODE, 2),
            handle(), "S8x8x16_NCHW44_DIRECT");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_NCHW44_8832) {
    checker_conv_bias_qint8x8x32(
            get_nchw44_conv_bias_args(
                    {2, 3, 5, 7}, ONLY_IDENTITY_NLMODE, ONLY_BR_BIASMODE, 1),
            handle(), "S8_NCHW44_DIRECT");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_NCHW44_8832) {
    checker_conv_bias_qint8x8x32(
            get_nchw44_conv_bias_args(
                    {2, 3, 5, 7}, ONLY_IDENTITY_NLMODE, ONLY_NO_BIASMODE, 2),
            handle(), "S8_NCHW44_DIRECT");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_NCHW44) {
    checker_conv_bias_qint8x8x8(
            get_nchw44_conv_bias_args({2, 3, 5, 7}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 2),
            handle(), "S8_NCHW44_DIRECT");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QS8_CHANNEL_WISE_DIRECT1_NCHW44) {
    checker_conv_bias_qint8x8x8(
            get_nchw44_conv_bias_args(
                    {2, 3, 5, 7}, ONLY_IDENTITY_NLMODE, BR_AND_NO_BIASMODE, 1),
            handle(), "S8_NCHW44_DIRECT");
    checker_conv_bias_qint8x8x8(
            get_nchw44_channel_wise_args({2, 3, 5}, 1, false, false, true), handle(),
            "S8_CHAN_WISE_STRD1_NCHW44");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QS8_CHANNEL_WISE_DIRECT2_NCHW44) {
    checker_conv_bias_qint8x8x8(
            get_nchw44_channel_wise_args({2, 3, 5}, 2, false, false, true), handle(),
            "S8_CHAN_WISE_STRD2_NCHW44");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_NCHW_NCHW44_S1) {
    checker_conv_bias_qint8x8x8(
            get_nchw44_conv_bias_args(
                    {2, 3, 5, 7}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 1, false, true),
            handle(), "S8_CONV_NCHW_NCHW44");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_NCHW_NCHW44_S2) {
    checker_conv_bias_qint8x8x8(
            get_nchw44_conv_bias_args(
                    {2, 3, 5, 7}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 2, false, true),
            handle(), "S8_CONV_NCHW_NCHW44");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_NCHW_NCHW44_S1_F1) {
    checker_conv_bias_qint8x8x8(
            get_nchw44_conv_bias_args(
                    {1}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 1, false, true),
            handle(), "S8_CONV_NCHW_NCHW44");
}
/*****************************quint8 direct****************************/
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE1) {
    checker_conv_bias_quint8x8x8(
            get_int8_quint8_conv_bias_args({2, 3, 5, 7}, 1, false, false, false),
            handle(), "QU8STRD1");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE2) {
    checker_conv_bias_quint8x8x8(
            get_int8_quint8_conv_bias_args({2, 3, 5, 7}, 2, false, false, false),
            handle(), "QU8STRD2");
}
/****************************dot qint8 direct*************************/
#if MGB_ENABLE_DOT
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_NCHW_NCHW44) {
    auto args = get_nchw44_conv_bias_args(
            {2, 3, 5, 7}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 2, false, true);
    for (auto&& arg : args) {
        arg.param.format = param::ConvBias::Format::NCHW44_DOT;
    }
    checker_conv_bias_qint8x8x8(args, handle(), "ARMDOTS8_NCHW_NCHW44");
    args = get_nchw44_conv_bias_args(
            {2, 3, 5, 7}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 1, false, true);
    for (auto&& arg : args) {
        arg.param.format = param::ConvBias::Format::NCHW44_DOT;
    }
    checker_conv_bias_qint8x8x8(args, handle(), "ARMDOTS8_NCHW_NCHW44");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_NCHW_NCHW44_S1_F1) {
    auto args = get_nchw44_conv_bias_args(
            {1}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 1, false, true);
    for (auto&& arg : args) {
        arg.param.format = param::ConvBias::Format::NCHW44_DOT;
    }
    checker_conv_bias_qint8x8x8(args, handle(), "ARMDOTS8_NCHW_NCHW44");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_WITHDOTPROD) {
    checker_conv_bias_qint8x8x8(
            get_int8_quint8_conv_bias_args({2, 3, 5, 7}, 1, false, false, false),
            handle(), "ARMDOTS8STRD1");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_WITHDOTPROD) {
    checker_conv_bias_qint8x8x8(
            get_int8_quint8_conv_bias_args({2, 3, 5, 7}, 2, false, false, false),
            handle(), "ARMDOTS8STRD2");
}
/****************************dot 8-8-32 direct*************************/
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_I8832STRD1_WITHDOT) {
    checker_conv_bias_qint8x8x32(
            get_conv_bias_args({2, 3, 5, 7}, 1, false, true, true), handle(),
            "ARMDOTS8STRD1");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_I8832STRD2_WITHDOT) {
    checker_conv_bias_qint8x8x32(
            get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(),
            "ARMDOTS8STRD2");
}
/******************************dot quint8*****************************/
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE1_WITHDOTPROD) {
    checker_conv_bias_quint8x8x8(
            get_int8_quint8_conv_bias_args({2, 3, 5, 7}, 1, false, false, false),
            handle(), "ARMDOTU8STRD1");
}
//! TODO: this test does not cover kernel size 3; adding it currently causes a
//! bus error on armv7
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE2_WITHDOTPROD) {
    checker_conv_bias_quint8x8x8(
            get_int8_quint8_conv_bias_args({2, 5, 7}, 2, false, false, false), handle(),
            "ARMDOTU8STRD2");
}
/******************************dot quint8x8x32***********************/
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_QUINT8_DIRECT_STRIDE1) {
    checker_conv_bias_quint8x8x32(
            get_conv_bias_args({2, 3, 5, 7}, 1, false, true, true), handle(),
            "ARMDOTU8STRD1");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_QUINT8_DIRECT_STRIDE2) {
    checker_conv_bias_quint8x8x32(
            get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(),
            "ARMDOTU8STRD2");
}
/******************************dot int8x8x8 nchw44 ***********************/
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_DIRECT_DOT_NCHW44_S1_Q8x8x8) {
    using namespace conv_bias;
    std::vector<TestArg> args =
            get_nchw44_conv_bias_args({2, 3, 5, 7}, QUAN_NLMODE, ONLY_BR_BIASMODE, 1);
    for (auto&& arg : args)
        arg.param.format = param::ConvBias::Format::NCHW44_DOT;
    checker_conv_bias_qint8x8x8(args, handle(), "ARMDOTS8DIRECT_NCHW44");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_DIRECT_DOT_NCHW44_S1_Q8x8x32) {
    using namespace conv_bias;
    std::vector<TestArg> args = get_nchw44_conv_bias_args(
            {2, 3, 5, 7}, ONLY_IDENTITY_NLMODE, BR_AND_NO_BIASMODE, 1);
    for (auto&& arg : args)
        arg.param.format = param::ConvBias::Format::NCHW44_DOT;
    checker_conv_bias_qint8x8x32(args, handle(), "ARMDOTS8DIRECT_NCHW44");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_DIRECT_DOT_NCHW44_S1_8x8x32) {
    using namespace conv_bias;
    std::vector<TestArg> args = get_nchw44_conv_bias_args(
            {2, 3, 5, 7}, ONLY_IDENTITY_NLMODE, BR_AND_NO_BIASMODE, 1);
    for (auto&& arg : args)
        arg.param.format = param::ConvBias::Format::NCHW44_DOT;
    checker_conv_bias_int8x8x32_multi(args, handle(), "ARMDOTS8DIRECT_NCHW44");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_DIRECT_DOT_NCHW44_S2_Q8x8x8) {
    using namespace conv_bias;
    //! test qint8x8x8
    std::vector<TestArg> args =
            get_nchw44_conv_bias_args({2, 3, 5, 7}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 2);
    for (auto&& arg : args)
        arg.param.format = param::ConvBias::Format::NCHW44_DOT;
    checker_conv_bias_qint8x8x8(args, handle(), "ARMDOTS8DIRECT_NCHW44");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_DIRECT_DOT_NCHW44_S2_Q8x8x32) {
    using namespace conv_bias;
    //! test qint8x8x32
    std::vector<TestArg> args = get_nchw44_conv_bias_args(
            {2, 3, 5, 7}, ONLY_IDENTITY_NLMODE, BR_AND_NO_BIASMODE, 2);
    for (auto&& arg : args)
        arg.param.format = param::ConvBias::Format::NCHW44_DOT;
    checker_conv_bias_qint8x8x32(args, handle(), "ARMDOTS8DIRECT_NCHW44");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_DIRECT_DOT_NCHW44_S2_8x8x32) {
    using namespace conv_bias;
    //! test int8x8x32
    std::vector<TestArg> args = get_nchw44_conv_bias_args(
            {2, 3, 5, 7}, ONLY_IDENTITY_NLMODE, BR_AND_NO_BIASMODE, 2);
    for (auto&& arg : args)
        arg.param.format = param::ConvBias::Format::NCHW44_DOT;
    checker_conv_bias_int8x8x32_multi(args, handle(), "ARMDOTS8DIRECT_NCHW44");
}
#endif
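/**********************************winograd************************/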
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD) {
    using namespace conv_bias;
    std::vector<TestArg> args = get_winograd_args(3);
    Checker<ConvBiasForward> checker(handle());
    auto run = [&checker](
                       const std::vector<TestArg>& args, DType A_dtype, DType B_dtype,
                       DType C_dtype, DType D_dtype, const float eps) {
        for (auto&& arg : args) {
            checker.set_dtype(0, A_dtype)
                    .set_dtype(1, B_dtype)
                    .set_dtype(2, C_dtype)
                    .set_dtype(4, D_dtype)
                    .set_epsilon(eps)
                    .set_param(arg.param)
                    .execs({arg.src, arg.filter, arg.bias, {}, {}});
        }
    };
    run(args, dtype::Float32(), dtype::Float32(), dtype::Float32(), dtype::Float32(),
        1e-3f);
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00);
    checker.set_rng(0, rng).set_rng(1, rng).set_rng(2, rng);
    run(args, dtype::Float16(), dtype::Float16(), dtype::Float16(), dtype::Float16(),
        0.35f);
#endif
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F23_4) {
    using namespace conv_bias;
    std::vector<TestArg> args = get_winograd_mk_packed_args();
    Checker<ConvBiasForward> checker(handle());
    check_winograd("4:2:32", checker, args, param::MatrixMul::Format::MK4);
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F23_4_NCHW44) {
    using namespace conv_bias;
    std::vector<TestArg> args =
            get_nchw44_conv_bias_args({3}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 1);
    Checker<ConvBiasForward> checker(handle());
    check_winograd(
            "4:2:32", checker, args, param::MatrixMul::Format::MK4,
            param::ConvBias::Format::NCHW44);
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F63) {
    using namespace conv_bias;
    std::vector<TestArg> args = get_winograd_args(3);
    Checker<ConvBiasForward> checker(handle());
    check_winograd("1:6:32", checker, args);
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F63_4) {
    using namespace conv_bias;
    std::vector<TestArg> args = get_winograd_mk_packed_args();
    Checker<ConvBiasForward> checker(handle());
    check_winograd("4:6:16", checker, args, param::MatrixMul::Format::MK4);
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F63_4_NCHW44) {
    using namespace conv_bias;
    std::vector<TestArg> args =
            get_nchw44_conv_bias_args({3}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 1);
    Checker<ConvBiasForward> checker(handle());
    check_winograd(
            "4:6:16", checker, args, param::MatrixMul::Format::MK4,
            param::ConvBias::Format::NCHW44);
}
//! uncomment it when low precision mode is ok
#if 0
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F73_4_NCHW44) {
    using namespace conv_bias;
    std::vector<TestArg> args =
            get_nchw44_conv_bias_args({3},QUAN_NLMODE,BR_AND_NO_BIASMODE,1);
    Checker<ConvBiasForward> checker(handle());
    check_winograd("4:7:16", checker, args, param::MatrixMul::Format::MK4,
            param::ConvBias::Format::NCHW44);
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F73_4_NCHW44_WEIGHT_PREPROCESS) {
    using namespace conv_bias;
    std::vector<TestArg> args =
            get_nchw44_conv_bias_args({3},QUAN_NLMODE,BR_AND_NO_BIASMODE,1);
    Checker<ConvBiasForward, OprWeightPreprocessProxy<ConvBiasForward>> checker(
            handle());
    check_winograd("4:7:16", checker, args, param::MatrixMul::Format::MK4,
            param::ConvBias::Format::NCHW44);
}
#endif
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F54) {
    using namespace conv_bias;
    std::vector<TestArg> args = get_winograd_args(4);
    Checker<ConvBiasForward> checker(handle());
    check_winograd("1:5:32", checker, args);
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F45) {
    using namespace conv_bias;
    std::vector<TestArg> args = get_winograd_args(5);
    Checker<ConvBiasForward> checker(handle());
    check_winograd("1:4:32", checker, args);
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_MK_PACKED_F32_1) {
    using namespace conv_bias;
    Checker<ConvBiasForward> checker(handle());
    auto run = [&checker](
                       const std::vector<TestArg>& args, DType A_dtype, DType B_dtype,
                       DType C_dtype, DType D_dtype, float eps) {
        for (auto&& arg : args) {
            checker.set_dtype(0, A_dtype)
                    .set_dtype(1, B_dtype)
                    .set_dtype(2, C_dtype)
                    .set_dtype(4, D_dtype)
                    .set_epsilon(eps)
                    .set_param(arg.param)
                    .execs({arg.src, arg.filter, arg.bias, {}, {}});
        }
    };
    std::vector<TestArg> args = get_winograd_mk_packed_args(8);
    std::vector<TestArg> args_first_half(args.begin(), args.begin() + args.size() / 2);
    run(args_first_half, dtype::Float32{}, dtype::Float32{}, dtype::Float32{},
        dtype::Float32{}, 1e-3f);
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_MK_PACKED_F32_2) {
    using namespace conv_bias;
    Checker<ConvBiasForward> checker(handle());
    auto run = [&checker](
                       const std::vector<TestArg>& args, DType A_dtype, DType B_dtype,
                       DType C_dtype, DType D_dtype, float eps) {
        for (auto&& arg : args) {
            checker.set_dtype(0, A_dtype)
                    .set_dtype(1, B_dtype)
                    .set_dtype(2, C_dtype)
                    .set_dtype(4, D_dtype)
                    .set_epsilon(eps)
                    .set_param(arg.param)
                    .execs({arg.src, arg.filter, arg.bias, {}, {}});
        }
    };
    std::vector<TestArg> args = get_winograd_mk_packed_args(8);
    std::vector<TestArg> args_second_half(args.begin() + args.size() / 2, args.end());
    run(args_second_half, dtype::Float32{}, dtype::Float32{}, dtype::Float32{},
        dtype::Float32{}, 1e-3f);
}
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_MK_PACKED_F16) {
    using namespace conv_bias;
    Checker<ConvBiasForward> checker(handle());
    auto run = [&checker](
                       const std::vector<TestArg>& args, DType A_dtype, DType B_dtype,
                       DType C_dtype, DType D_dtype, float eps) {
        for (auto&& arg : args) {
            checker.set_dtype(0, A_dtype)
                    .set_dtype(1, B_dtype)
                    .set_dtype(2, C_dtype)
                    .set_dtype(4, D_dtype)
                    .set_epsilon(eps)
                    .set_param(arg.param)
                    .execs({arg.src, arg.filter, arg.bias, {}, {}});
        }
    };
    std::vector<TestArg> args = get_winograd_mk_packed_args(8);
    Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00);
    checker.set_rng(0, rng).set_rng(1, rng).set_rng(2, rng);
    run(args, dtype::Float16{}, dtype::Float16{}, dtype::Float16{}, dtype::Float16{},
        0.25);
}
#endif
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_MK_PACKED_INT8) {
    using namespace conv_bias;
    Checker<ConvBiasForward> checker(handle());
    auto run = [&checker](
                       const std::vector<TestArg>& args, DType A_dtype, DType B_dtype,
                       DType C_dtype, DType D_dtype, float eps) {
        for (auto&& arg : args) {
            checker.set_dtype(0, A_dtype)
                    .set_dtype(1, B_dtype)
                    .set_dtype(2, C_dtype)
                    .set_dtype(4, D_dtype)
                    .set_epsilon(eps)
                    .set_param(arg.param)
                    .execs({arg.src, arg.filter, arg.bias, {}, {}});
        }
    };
#if MEGDNN_AARCH64
    const char* matmul_name = "AARCH64_INT16X16X32_MK8_8X8";
#else
    const char* matmul_name = "ARMV7_INT16X16X32_MK8_4X8";
#endif
    checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>(
            ssprintf("WINOGRAD:%s:8:2:32", matmul_name).c_str()));
    std::vector<TestArg> quantized_args = get_quantized_winograd_mk_packed_args(8);
    UniformIntRNG int_rng{-50, 50};
    checker.set_rng(0, &int_rng).set_rng(1, &int_rng).set_rng(2, &int_rng);
    run(quantized_args, dtype::QuantizedS8(2.5f), dtype::QuantizedS8(2.5f),
        dtype::QuantizedS32(6.25f), dtype::QuantizedS8(60.25f), 1e-3);
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8) {
    using namespace conv_bias;
    Checker<ConvBiasForward> checker(handle());
    auto run = [&checker](
                       const std::vector<TestArg>& args, DType A_dtype, DType B_dtype,
                       DType C_dtype, DType D_dtype, float eps) {
        for (auto&& arg : args) {
            checker.set_dtype(0, A_dtype)
                    .set_dtype(1, B_dtype)
                    .set_dtype(2, C_dtype)
                    .set_dtype(4, D_dtype)
                    .set_epsilon(eps)
                    .set_param(arg.param)
                    .execs({arg.src, arg.filter, arg.bias, {}, {}});
        }
    };
#if MEGDNN_AARCH64
    const char* matmul_name = "AARCH64_INT16X16X32_MK8_8X8";
#else
    const char* matmul_name = "ARMV7_INT16X16X32_MK8_4X8";
#endif
    checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>(
            ssprintf("WINOGRAD_NCHW44:%s:8:2:32", matmul_name).c_str()));
    std::vector<TestArg> quantized_args = get_int8_nchw44_args(3, 4);
    UniformIntRNG int_rng{-50, 50};
    checker.set_rng(0, &int_rng).set_rng(1, &int_rng).set_rng(2, &int_rng);
    run(quantized_args, dtype::QuantizedS8(2.5f), dtype::QuantizedS8(2.5f),
        dtype::QuantizedS32(6.25f), dtype::QuantizedS8(60.25f), 1e-3);
}
TEST_F(ARM_COMMON_MULTI_THREADS,
       CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8_COMP_F32_GROUPMODE) {
    using namespace conv_bias;
    Checker<ConvBiasForward> checker(handle());
    auto run = [&checker](
                       const std::vector<TestArg>& args, DType A_dtype, DType B_dtype,
                       DType C_dtype, DType D_dtype, float eps) {
        for (auto&& arg : args) {
            checker.set_dtype(0, A_dtype)
                    .set_dtype(1, B_dtype)
                    .set_dtype(2, C_dtype)
                    .set_dtype(4, D_dtype)
                    .set_epsilon(eps)
                    .set_param(arg.param)
                    .execs({arg.src, arg.filter, arg.bias, {}, {}});
        }
    };
    float epsilon = 0.001;
#if MEGDNN_AARCH64
    const char* matmul_name = "AARCH64_F32_MK4_4x16";
#else
    const char* matmul_name = "ARMV7_F32_MK4_4x8";
#endif
    checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>(
            ssprintf("WINOGRAD_NCHW44:%s:4:2:32", matmul_name).c_str()));
    std::vector<TestArg> quantized_args = get_int8_nchw44_args(3, 4, true, true);
    UniformIntRNG int_rng{-50, 50};
    checker.set_rng(0, &int_rng).set_rng(1, &int_rng).set_rng(2, &int_rng);
    run(quantized_args, dtype::QuantizedS8(0.41113496f),
        dtype::QuantizedS8(0.01887994f), dtype::QuantizedS32(0.41113496f * 0.01887994f),
        dtype::QuantizedS8(0.49550694f), epsilon);
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8_GROUPMODE) {
    using namespace conv_bias;
    Checker<ConvBiasForward> checker(handle());
    auto run = [&checker](
                       const std::vector<TestArg>& args, DType A_dtype, DType B_dtype,
                       DType C_dtype, DType D_dtype, float eps) {
        for (auto&& arg : args) {
            checker.set_dtype(0, A_dtype)
                    .set_dtype(1, B_dtype)
                    .set_dtype(2, C_dtype)
                    .set_dtype(4, D_dtype)
                    .set_epsilon(eps)
                    .set_param(arg.param)
                    .execs({arg.src, arg.filter, arg.bias, {}, {}});
        }
    };
#if MEGDNN_AARCH64
    const char* matmul_name = "AARCH64_INT16X16X32_MK8_8X8";
#else
    const char* matmul_name = "ARMV7_INT16X16X32_MK8_4X8";
#endif
    checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>(
            ssprintf("WINOGRAD_NCHW44:%s:8:2:32", matmul_name).c_str()));
    std::vector<TestArg> quantized_args = get_int8_nchw44_args(3, 4, false, true);
    UniformIntRNG int_rng{-50, 50};
    checker.set_rng(0, &int_rng).set_rng(1, &int_rng).set_rng(2, &int_rng);
    run(quantized_args, dtype::QuantizedS8(2.5f), dtype::QuantizedS8(2.5f),
        dtype::QuantizedS32(6.25f), dtype::QuantizedS8(60.25f), 1e-3);
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8_COMP_F32) {
    using namespace conv_bias;
    Checker<ConvBiasForward> checker(handle());
    auto run = [&checker](
                       const std::vector<TestArg>& args, DType A_dtype, DType B_dtype,
                       DType C_dtype, DType D_dtype, float eps) {
        for (auto&& arg : args) {
            checker.set_dtype(0, A_dtype)
                    .set_dtype(1, B_dtype)
                    .set_dtype(2, C_dtype)
                    .set_dtype(4, D_dtype)
                    .set_epsilon(eps)
                    .set_param(arg.param)
                    .execs({arg.src, arg.filter, arg.bias, {}, {}});
        }
    };
    float epsilon = 0.001;
#if MEGDNN_AARCH64
    const char* matmul_name = "AARCH64_F32_MK4_4x16";
#else
    const char* matmul_name = "ARMV7_F32_MK4_4x8";
#endif
    checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>(
            ssprintf("WINOGRAD_NCHW44:%s:4:2:32", matmul_name).c_str()));
    std::vector<TestArg> quantized_args = get_int8_nchw44_args(3, 4, true);
    UniformIntRNG int_rng{-50, 50};
    checker.set_rng(0, &int_rng).set_rng(1, &int_rng).set_rng(2, &int_rng);
    run(quantized_args, dtype::QuantizedS8(0.41113496f),
        dtype::QuantizedS8(0.01887994f), dtype::QuantizedS32(0.41113496f * 0.01887994f),
        dtype::QuantizedS8(0.49550694f), epsilon);
}
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_F23) {
    using namespace conv_bias;
    std::vector<TestArg> args = get_winograd_mk_packed_args();
    Checker<ConvBiasForward> checker(handle());
    check_winograd_fp16("1:2:32", checker, args, NULL, 0.08);
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_F45_1) {
    using namespace conv_bias;
    std::vector<TestArg> args = get_winograd_args(5);
    std::vector<TestArg> args_head_half(args.begin(), args.begin() + args.size() / 2);
    Checker<ConvBiasForward> checker(handle());
    //! fp16 range -1.0 ~ 1.0
    Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00);
    check_winograd_fp16("1:4:32", checker, args_head_half, rng, 0.25);
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_F45_2) {
    using namespace conv_bias;
    std::vector<TestArg> args = get_winograd_args(5);
    std::vector<TestArg> args_back_half(args.begin() + args.size() / 2, args.end());
    Checker<ConvBiasForward> checker(handle());
    //! fp16 range -1.0 ~ 1.0
    Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00);
    check_winograd_fp16("1:4:32", checker, args_back_half, rng, 0.25);
}
//! FIXME: this test may fail when run together with `ARM_COMMON.CONV_BIAS_WINOGRAD*`,
//! but it passes when run as a single testcase
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_F63) {
    using namespace conv_bias;
    std::vector<TestArg> args = get_winograd_args(3);
    Checker<ConvBiasForward> checker(handle());
    //! fp16 range -1.0 ~ 1.0
    Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00);
    check_winograd_fp16("1:6:32", checker, args, rng, 0.3);
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_8x8_1) {
    using namespace conv_bias;
    std::vector<TestArg> args = get_winograd_mk_packed_args(8);
    std::vector<TestArg> args_head_half(args.begin(), args.begin() + args.size() / 2);
    Checker<ConvBiasForward> checker(handle());
    Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00);
    check_winograd_fp16(
            "8:2:32", checker, args_head_half, rng, 0.25,
            param::MatrixMul::Format::MK8);
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_8x8_2) {
    using namespace conv_bias;
    std::vector<TestArg> args = get_winograd_mk_packed_args(8);
    std::vector<TestArg> args_back_half(args.begin() + args.size() / 2, args.end());
    Checker<ConvBiasForward> checker(handle());
    Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00);
    check_winograd_fp16(
            "8:2:32", checker, args_back_half, rng, 0.25,
            param::MatrixMul::Format::MK8);
}
#endif
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_INT8_8X8) {
    using namespace conv_bias;
    std::vector<TestArg> args = get_quantized_winograd_mk_packed_args(8);
    Checker<ConvBiasForward> checker(handle());
    UniformIntRNG rng{-50, 50};
    checker.set_dtype(0, dtype::QuantizedS8(2.5f))
            .set_dtype(1, dtype::QuantizedS8(2.5f))
            .set_dtype(2, dtype::QuantizedS32(6.25f))
            .set_dtype(4, dtype::QuantizedS8(60.25f))
            .set_rng(0, &rng)
            .set_rng(1, &rng)
            .set_rng(2, &rng);
    check_winograd("8:2:32", checker, args, param::MatrixMul::Format::MK8);
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_INT8_8X8_WEIGHT_PREPROCESS) {
    using namespace conv_bias;
    std::vector<TestArg> args = get_quantized_winograd_mk_packed_args(8);
    Checker<ConvBiasForward, OprWeightPreprocessProxy<ConvBiasForward>> checker(
            handle());
    UniformIntRNG rng{-50, 50};
    checker.set_dtype(0, dtype::QuantizedS8(2.5f))
            .set_dtype(1, dtype::QuantizedS8(2.5f))
            .set_dtype(2, dtype::QuantizedS32(6.25f))
            .set_dtype(4, dtype::QuantizedS8(60.25f))
            .set_rng(0, &rng)
            .set_rng(1, &rng)
            .set_rng(2, &rng);
    check_winograd("8:2:32", checker, args, param::MatrixMul::Format::MK8);
}
// clang-format on
// vim: syntax=cpp.doxygen