You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

conv_bias_multi_thread.cpp 42 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066
  1. /**
  2. * \file dnn/test/arm_common/conv_bias_multi_thread.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #include "megdnn/dtype.h"
  13. #include "test/arm_common/fixture.h"
  14. #include "test/common/benchmarker.h"
  15. #include "test/common/conv_bias.h"
  16. #include "test/arm_common/cpuinfo_help.h"
  17. using namespace megdnn;
  18. using namespace test;
  19. using namespace conv_bias;
  20. std::vector<conv_bias::TestArg> get_int8_quint8_conv_bias_args(
  21. std::vector<size_t> kernel, size_t stride, bool no_pad, bool no_bias,
  22. bool no_nonlinemode) {
  23. using namespace conv_bias;
  24. using Param = param::ConvBias;
  25. using NLMode = param::ConvBias::NonlineMode;
  26. std::vector<TestArg> args;
  27. auto pack = [&](size_t n, size_t oc, size_t ic, size_t w, size_t h, size_t kernel,
  28. size_t stride, NLMode nlmode) {
  29. Param param;
  30. param.stride_h = stride;
  31. param.stride_w = stride;
  32. if (!no_pad) {
  33. param.pad_h = kernel / 2;
  34. param.pad_w = kernel / 2;
  35. } else {
  36. param.pad_h = 0;
  37. param.pad_w = 0;
  38. }
  39. param.nonlineMode = nlmode;
  40. args.emplace_back(
  41. param, TensorShape{n, ic, h, w}, TensorShape{oc, ic, kernel, kernel},
  42. TensorShape{});
  43. if (!no_bias) {
  44. args.emplace_back(
  45. param, TensorShape{n, ic, h, w},
  46. TensorShape{oc, ic, kernel, kernel}, TensorShape{1, oc, 1, 1});
  47. }
  48. };
  49. std::vector<NLMode> nonlinemode = {NLMode::IDENTITY};
  50. if (!no_nonlinemode) {
  51. nonlinemode.emplace_back(NLMode::RELU);
  52. nonlinemode.emplace_back(NLMode::H_SWISH);
  53. }
  54. for (size_t n : {1, 2}) {
  55. for (auto nlmode : nonlinemode) {
  56. for (size_t ic : {1, 3, 7}) {
  57. for (size_t oc : {1, 3, 7}) {
  58. for (size_t size : {4, 6, 8, 14, 16, 18}) {
  59. for (size_t kern : kernel) {
  60. pack(n, oc, ic, size, size, kern, stride, nlmode);
  61. }
  62. }
  63. }
  64. }
  65. }
  66. }
  67. return args;
  68. }
  69. std::vector<conv_bias::TestArg> get_nchw44_channel_wise_args(
  70. std::vector<size_t> kernel, size_t stride, bool no_bias, bool no_nonlinemode,
  71. bool no_full_bias) {
  72. using namespace conv_bias;
  73. using Param = param::ConvBias;
  74. using NLMode = param::ConvBias::NonlineMode;
  75. std::vector<TestArg> args;
  76. auto pack = [&](size_t n, size_t group, size_t w, size_t h, size_t kernel,
  77. size_t stride, NLMode nlmode, bool pad) {
  78. Param param;
  79. param.stride_h = stride;
  80. param.stride_w = stride;
  81. if (pad) {
  82. param.pad_h = kernel / 2;
  83. param.pad_w = kernel / 2;
  84. } else {
  85. param.pad_h = 0;
  86. param.pad_w = 0;
  87. }
  88. param.nonlineMode = nlmode;
  89. param.format = param::ConvBias::Format::NCHW44;
  90. param.sparse = param::ConvBias::Sparse::GROUP;
  91. args.emplace_back(
  92. param, TensorShape{n, group, h, w, 4},
  93. TensorShape{group, 1, 1, kernel, kernel, 4}, TensorShape{});
  94. if (!no_bias) {
  95. args.emplace_back(
  96. param, TensorShape{n, group, h, w, 4},
  97. TensorShape{group, 1, 1, kernel, kernel, 4},
  98. TensorShape{1, group, 1, 1, 4});
  99. }
  100. if (!no_full_bias) {
  101. args.emplace_back(
  102. param, TensorShape{n, group, h, w, 4},
  103. TensorShape{group, 1, 1, kernel, kernel, 4},
  104. TensorShape{
  105. n, group, (h + 2 * param.pad_w - kernel) / stride + 1,
  106. (w + 2 * param.pad_w - kernel) / stride + 1, 4});
  107. }
  108. };
  109. std::vector<NLMode> nonlinemode = {NLMode::IDENTITY};
  110. if (!no_nonlinemode) {
  111. nonlinemode.emplace_back(NLMode::RELU);
  112. nonlinemode.emplace_back(NLMode::H_SWISH);
  113. }
  114. for (size_t n : {1, 2}) {
  115. for (auto nlmode : nonlinemode) {
  116. for (bool pad : {true}) {
  117. for (size_t group : {1, 2, 4, 7, 16}) {
  118. for (size_t size : {4, 6, 7, 9, 20}) {
  119. for (size_t kern : kernel) {
  120. pack(n, group, size, size, kern, stride, nlmode, pad);
  121. }
  122. }
  123. }
  124. }
  125. for (bool pad : {false}) {
  126. for (size_t group : {1, 2, 7, 16}) {
  127. for (size_t size : {7, 9, 20}) {
  128. for (size_t kern : kernel) {
  129. pack(n, group, size, size, kern, stride, nlmode, pad);
  130. }
  131. }
  132. }
  133. }
  134. }
  135. }
  136. return args;
  137. }
  138. std::vector<conv_bias::TestArg> get_nchw88_channel_wise_args(
  139. std::vector<size_t> kernel, size_t stride, bool no_bias, bool no_nonlinemode,
  140. bool no_full_bias) {
  141. using namespace conv_bias;
  142. using Param = param::ConvBias;
  143. using NLMode = param::ConvBias::NonlineMode;
  144. std::vector<TestArg> args;
  145. auto pack = [&](size_t n, size_t group, size_t w, size_t h, size_t kernel,
  146. size_t stride, NLMode nlmode, bool pad) {
  147. Param param;
  148. param.stride_h = stride;
  149. param.stride_w = stride;
  150. if (pad) {
  151. param.pad_h = kernel / 2;
  152. param.pad_w = kernel / 2;
  153. } else {
  154. param.pad_h = 0;
  155. param.pad_w = 0;
  156. }
  157. param.nonlineMode = nlmode;
  158. param.format = param::ConvBias::Format::NCHW88;
  159. param.sparse = param::ConvBias::Sparse::GROUP;
  160. args.emplace_back(
  161. param, TensorShape{n, group, h, w, 8},
  162. TensorShape{group, 1, 1, kernel, kernel, 8}, TensorShape{});
  163. if (!no_bias) {
  164. args.emplace_back(
  165. param, TensorShape{n, group, h, w, 8},
  166. TensorShape{group, 1, 1, kernel, kernel, 8},
  167. TensorShape{1, group, 1, 1, 8});
  168. }
  169. if (!no_full_bias) {
  170. args.emplace_back(
  171. param, TensorShape{n, group, h, w, 8},
  172. TensorShape{group, 1, 1, kernel, kernel, 8},
  173. TensorShape{
  174. n, group, (h + 2 * param.pad_w - kernel) / stride + 1,
  175. (w + 2 * param.pad_w - kernel) / stride + 1, 8});
  176. }
  177. };
  178. std::vector<NLMode> nonlinemode = {NLMode::IDENTITY};
  179. if (!no_nonlinemode) {
  180. nonlinemode.emplace_back(NLMode::RELU);
  181. nonlinemode.emplace_back(NLMode::H_SWISH);
  182. }
  183. for (size_t n : {1, 2}) {
  184. for (auto nlmode : nonlinemode) {
  185. for (bool pad : {true}) {
  186. for (size_t group : {1, 2, 4, 7, 8, 16}) {
  187. for (size_t size : {4, 6, 7, 9, 20}) {
  188. for (size_t kern : kernel) {
  189. pack(n, group, size, size, kern, stride, nlmode, pad);
  190. }
  191. }
  192. }
  193. }
  194. for (bool pad : {false}) {
  195. for (size_t group : {1, 2, 7, 16}) {
  196. for (size_t size : {7, 9, 20}) {
  197. for (size_t kern : kernel) {
  198. pack(n, group, size, size, kern, stride, nlmode, pad);
  199. }
  200. }
  201. }
  202. }
  203. }
  204. }
  205. return args;
  206. }
  207. void checker_conv_bias_qint8x8x8(
  208. std::vector<conv_bias::TestArg> args, Handle* handle, const char* algo_name) {
  209. Checker<ConvBias> checker(handle);
  210. checker.set_before_exec_callback(
  211. conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name));
  212. #if MEGDNN_ARMV7
  213. checker.set_epsilon(1);
  214. #endif
  215. UniformIntRNG rng{-50, 50};
  216. checker.set_dtype(0, dtype::QuantizedS8(0.41113496f))
  217. .set_dtype(1, dtype::QuantizedS8(0.01887994f))
  218. .set_dtype(2, dtype::QuantizedS32(0.41113496f * 0.01887994f))
  219. .set_dtype(4, dtype::QuantizedS8(0.49550694f))
  220. .set_rng(0, &rng)
  221. .set_rng(1, &rng)
  222. .set_rng(2, &rng);
  223. for (auto&& arg : args) {
  224. checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}});
  225. }
  226. }
  227. void checker_conv_bias_qint8x8x32(
  228. std::vector<conv_bias::TestArg> args, Handle* handle, const char* algo_name) {
  229. Checker<ConvBias> checker(handle);
  230. UniformIntRNG rng{-50, 50};
  231. checker.set_dtype(0, dtype::QuantizedS8(2.5f))
  232. .set_dtype(1, dtype::QuantizedS8(2.5f))
  233. .set_dtype(2, dtype::QuantizedS32(6.25f))
  234. .set_dtype(4, {});
  235. checker.set_before_exec_callback(
  236. conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name));
  237. for (auto&& arg : args) {
  238. checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}});
  239. }
  240. }
  241. void checker_conv_bias_quint8x8x8(
  242. std::vector<conv_bias::TestArg> args, Handle* handle, const char* algo_name) {
  243. Checker<ConvBias> checker(handle);
  244. checker.set_before_exec_callback(
  245. conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name));
  246. UniformIntRNG rng(0, 255);
  247. checker.set_dtype(0, dtype::Quantized8Asymm(0.2f, 100))
  248. .set_dtype(1, dtype::Quantized8Asymm(0.2f, 120))
  249. .set_dtype(2, dtype::QuantizedS32(0.04f))
  250. .set_dtype(4, dtype::Quantized8Asymm(1.4f, 110))
  251. .set_rng(0, &rng)
  252. .set_rng(1, &rng)
  253. .set_rng(2, &rng);
  254. for (auto&& arg : args) {
  255. checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}});
  256. }
  257. }
  258. void checker_conv_bias_quint8x8x32(
  259. std::vector<conv_bias::TestArg> args, Handle* handle, const char* algo_name) {
  260. Checker<ConvBias> checker(handle);
  261. checker.set_before_exec_callback(
  262. conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name));
  263. NormalRNG rng(128.f);
  264. checker.set_rng(0, &rng).set_rng(1, &rng);
  265. checker.set_dtype(0, dtype::Quantized8Asymm(1.2f, (uint8_t)127))
  266. .set_dtype(1, dtype::Quantized8Asymm(1.3f, (uint8_t)129))
  267. .set_dtype(2, dtype::QuantizedS32(1.2 * 1.3))
  268. .set_dtype(4, {});
  269. for (auto&& arg : args) {
  270. checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}});
  271. }
  272. }
  273. void checker_conv_bias_int8x8x32_multi(
  274. std::vector<conv_bias::TestArg> args, Handle* handle, const char* algo_name) {
  275. Checker<ConvBias> checker(handle);
  276. checker.set_before_exec_callback(
  277. conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name));
  278. checker.set_dtype(0, dtype::Int8());
  279. checker.set_dtype(1, dtype::Int8());
  280. checker.set_dtype(2, dtype::Int32());
  281. checker.set_dtype(4, dtype::Int32());
  282. for (auto&& arg : args) {
  283. checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}});
  284. }
  285. }
  286. /**********************************F16 direct************************/
  287. #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  288. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16) {
  289. NormalRNG rng(1);
  290. checker_conv_bias_f16(
  291. get_conv_bias_args({1, 2, 3, 4, 5, 6, 7}, 1, false, false, false), handle(),
  292. rng, "F16DIRECT", 0.03);
  293. }
  294. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16_STR1) {
  295. NormalRNG rng(1);
  296. checker_conv_bias_f16(
  297. get_conv_bias_args({2, 3, 5}, 1, false, false, false), handle(), rng,
  298. "F16STRD1", 0.03);
  299. }
  300. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CHANNEL_WISE_STRIDE1_FP16_NCHW88_1) {
  301. NormalRNG rng(1);
  302. checker_conv_bias_f16(
  303. get_nchw88_channel_wise_args({2, 3}, 1, false, false, false), handle(), rng,
  304. "F16_CHANNEL_WISE_NCHW88", 0.03);
  305. }
  306. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CHANNEL_WISE_STRIDE1_FP16_NCHW88_2) {
  307. NormalRNG rng(1);
  308. checker_conv_bias_f16(
  309. get_nchw88_channel_wise_args({5}, 1, false, false, false), handle(), rng,
  310. "F16_CHANNEL_WISE_NCHW88", 0.03);
  311. }
  312. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CHANNEL_WISE_STRIDE2_FP16_NCHW88) {
  313. NormalRNG rng(1);
  314. checker_conv_bias_f16(
  315. get_nchw88_channel_wise_args({2, 3, 5}, 2, false, false, false), handle(),
  316. rng, "F16_CHANNEL_WISE_NCHW88", 0.03);
  317. }
  318. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16_NCHW88_S1) {
  319. NormalRNG rng(1);
  320. checker_conv_bias_f16(
  321. get_nchw88_conv_bias_args({1, 2, 3, 5, 7}, FULL_NLMODE, ALL_BIASMODE, 1),
  322. handle(), rng, "F16_CONV_NCHW88_DIRECT", 0.03);
  323. }
  324. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16_NCHW88_S2) {
  325. NormalRNG rng(1);
  326. checker_conv_bias_f16(
  327. get_nchw88_conv_bias_args({1, 2, 3, 5, 7}, FULL_NLMODE, ALL_BIASMODE, 2),
  328. handle(), rng, "F16_CONV_NCHW88_DIRECT", 0.03);
  329. }
  330. #endif
  331. /**********************************algo 8816 direct************************/
  332. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT16_DIRECT) {
  333. checker_conv_bias_int8x8x16(
  334. get_conv_bias_args({2, 3, 5}, 1, false, true, true), handle(),
  335. "I8816DIRECT");
  336. }
  337. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT16_STRIDE2) {
  338. checker_conv_bias_int8x8x16(
  339. get_conv_bias_args({2, 3, 5}, 2, false, true, true), handle(),
  340. "I8816STRD2");
  341. }
  342. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT16_NCHW_NCHW44_S2) {
  343. checker_conv_bias_int8x8x16(
  344. get_nchw44_conv_bias_args(
  345. {2, 3, 5, 7}, ONLY_IDENTITY_NLMODE, ONLY_NO_BIASMODE, 2, false,
  346. true),
  347. handle(), "I8816_CONV_NCHW_NCHW44");
  348. }
  349. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT16_NCHW_NCHW44_S1) {
  350. checker_conv_bias_int8x8x16(
  351. get_nchw44_conv_bias_args(
  352. {2, 3, 5, 7}, ONLY_IDENTITY_NLMODE, ONLY_NO_BIASMODE, 1, false,
  353. true),
  354. handle(), "I8816_CONV_NCHW_NCHW44");
  355. }
  356. /**********************************algo 8-8-32 direct************************/
  357. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT32_STRIDE1) {
  358. checker_conv_bias_int8x8x32_multi(
  359. get_conv_bias_args({2, 3, 5, 7}, 1, false, true, true), handle(),
  360. "S8STRD1");
  361. }
  362. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT32_STRIDE2) {
  363. checker_conv_bias_int8x8x32_multi(
  364. get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(),
  365. "S8STRD2");
  366. }
  367. TEST_F(ARM_COMMON_MULTI_THREADS,
  368. CONV_BIAS_INT8_INT8_INT32_CHANNEL_WISE_DIRECT1_NCHW44) {
  369. checker_conv_bias_int8x8x32_multi(
  370. get_nchw44_channel_wise_args({2, 3, 5}, 1, false, true, true), handle(),
  371. "S8_CHAN_WISE_STRD1_NCHW44");
  372. }
  373. TEST_F(ARM_COMMON_MULTI_THREADS,
  374. CONV_BIAS_INT8_INT8_INT32_CHANNEL_WISE_DIRECT2_NCHW44) {
  375. checker_conv_bias_int8x8x32_multi(
  376. get_nchw44_channel_wise_args({2, 3, 5}, 2, false, true, true), handle(),
  377. "S8_CHAN_WISE_STRD2_NCHW44");
  378. }
  379. TEST_F(ARM_COMMON, CONV_BIAS_INT8_INT8_INT16_CHANNEL_WISE_DIRECT1_NCHW44) {
  380. Checker<ConvBias> checker(handle());
  381. checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>(
  382. "S8x8x16_CHAN_WISE_STRD1_STRD2_NCHW44"));
  383. checker.set_dtype(0, dtype::Int8());
  384. checker.set_dtype(1, dtype::Int8());
  385. checker.set_dtype(2, dtype::Int16());
  386. checker.set_dtype(4, dtype::Int16());
  387. auto args = get_nchw44_channel_wise_args({2, 3, 5}, 1, false, true, true);
  388. for (auto&& arg : args) {
  389. checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}});
  390. }
  391. }
  392. TEST_F(ARM_COMMON_MULTI_THREADS,
  393. CONV_BIAS_INT8_INT8_INT16_CHANNEL_WISE_DIRECT2_NCHW44) {
  394. Checker<ConvBias> checker(handle());
  395. checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>(
  396. "S8x8x16_CHAN_WISE_STRD1_STRD2_NCHW44"));
  397. checker.set_dtype(0, dtype::Int8());
  398. checker.set_dtype(1, dtype::Int8());
  399. checker.set_dtype(2, dtype::Int16());
  400. checker.set_dtype(4, dtype::Int16());
  401. auto args = get_nchw44_channel_wise_args({2, 3, 5}, 2, false, true, true);
  402. for (auto&& arg : args) {
  403. checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}});
  404. }
  405. }
  406. /********************************qint8 direct******************************/
  407. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1) {
  408. checker_conv_bias_qint8x8x8(
  409. get_int8_quint8_conv_bias_args({2, 3, 5, 7}, 1, false, false, false),
  410. handle(), "S8STRD1");
  411. }
  412. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2) {
  413. checker_conv_bias_qint8x8x8(
  414. get_int8_quint8_conv_bias_args({2, 3, 5, 7}, 2, false, false, false),
  415. handle(), "S8STRD2");
  416. }
  417. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_NCHW44) {
  418. checker_conv_bias_qint8x8x8(
  419. get_nchw44_conv_bias_args({2, 3, 5, 7}, QUAN_NLMODE, ONLY_BR_BIASMODE, 1),
  420. handle(), "S8_NCHW44_DIRECT");
  421. }
  422. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_NCHW44_8816) {
  423. checker_conv_bias_int8x8x16(
  424. get_nchw44_conv_bias_args(
  425. {2, 3, 5, 7}, ONLY_IDENTITY_NLMODE, ONLY_BR_BIASMODE, 1),
  426. handle(), "S8x8x16_NCHW44_DIRECT");
  427. }
  428. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_NCHW44_8816) {
  429. checker_conv_bias_int8x8x16(
  430. get_nchw44_conv_bias_args(
  431. {2, 3, 5, 7}, ONLY_IDENTITY_NLMODE, ONLY_BR_BIASMODE, 2),
  432. handle(), "S8x8x16_NCHW44_DIRECT");
  433. }
  434. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_NCHW44_8832) {
  435. checker_conv_bias_qint8x8x32(
  436. get_nchw44_conv_bias_args(
  437. {2, 3, 5, 7}, ONLY_IDENTITY_NLMODE, ONLY_BR_BIASMODE, 1),
  438. handle(), "S8_NCHW44_DIRECT");
  439. }
  440. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_NCHW44_8832) {
  441. checker_conv_bias_qint8x8x32(
  442. get_nchw44_conv_bias_args(
  443. {2, 3, 5, 7}, ONLY_IDENTITY_NLMODE, ONLY_NO_BIASMODE, 2),
  444. handle(), "S8_NCHW44_DIRECT");
  445. }
  446. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_NCHW44) {
  447. checker_conv_bias_qint8x8x8(
  448. get_nchw44_conv_bias_args({2, 3, 5, 7}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 2),
  449. handle(), "S8_NCHW44_DIRECT");
  450. }
  451. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QS8_CHANNEL_WISE_DIRECT1_NCHW44) {
  452. checker_conv_bias_qint8x8x8(
  453. get_nchw44_conv_bias_args(
  454. {2, 3, 5, 7}, ONLY_IDENTITY_NLMODE, BR_AND_NO_BIASMODE, 1),
  455. handle(), "S8_NCHW44_DIRECT");
  456. checker_conv_bias_qint8x8x8(
  457. get_nchw44_channel_wise_args({2, 3, 5}, 1, false, false, true), handle(),
  458. "S8_CHAN_WISE_STRD1_NCHW44");
  459. }
  460. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QS8_CHANNEL_WISE_DIRECT2_NCHW44) {
  461. checker_conv_bias_qint8x8x8(
  462. get_nchw44_channel_wise_args({2, 3, 5}, 2, false, false, true), handle(),
  463. "S8_CHAN_WISE_STRD2_NCHW44");
  464. }
  465. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_NCHW_NCHW44_S1) {
  466. checker_conv_bias_qint8x8x8(
  467. get_nchw44_conv_bias_args(
  468. {2, 3, 5, 7}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 1, false, true),
  469. handle(), "S8_CONV_NCHW_NCHW44");
  470. }
  471. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_NCHW_NCHW44_S2) {
  472. checker_conv_bias_qint8x8x8(
  473. get_nchw44_conv_bias_args(
  474. {2, 3, 5, 7}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 2, false, true),
  475. handle(), "S8_CONV_NCHW_NCHW44");
  476. }
  477. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_NCHW_NCHW44_S1_F1) {
  478. checker_conv_bias_qint8x8x8(
  479. get_nchw44_conv_bias_args(
  480. {1}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 1, false, true),
  481. handle(), "S8_CONV_NCHW_NCHW44");
  482. }
  483. /*****************************quint8 direct****************************/
  484. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE1) {
  485. checker_conv_bias_quint8x8x8(
  486. get_int8_quint8_conv_bias_args({2, 3, 5, 7}, 1, false, false, false),
  487. handle(), "QU8STRD1");
  488. }
  489. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE2) {
  490. checker_conv_bias_quint8x8x8(
  491. get_int8_quint8_conv_bias_args({2, 3, 5, 7}, 2, false, false, false),
  492. handle(), "QU8STRD2");
  493. }
  494. /****************************dot qint8 direct*************************/
  495. #if MGB_ENABLE_DOT
  496. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_NCHW_NCHW44) {
  497. auto args = get_nchw44_conv_bias_args(
  498. {2, 3, 5, 7}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 2, false, true);
  499. for (auto&& arg : args) {
  500. arg.param.format = param::ConvBias::Format::NCHW44_DOT;
  501. }
  502. checker_conv_bias_qint8x8x8(args, handle(), "ARMDOTS8_NCHW_NCHW44");
  503. args = get_nchw44_conv_bias_args(
  504. {2, 3, 5, 7}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 1, false, true);
  505. for (auto&& arg : args) {
  506. arg.param.format = param::ConvBias::Format::NCHW44_DOT;
  507. }
  508. checker_conv_bias_qint8x8x8(args, handle(), "ARMDOTS8_NCHW_NCHW44");
  509. }
  510. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_NCHW_NCHW44_S1_F1) {
  511. auto args = get_nchw44_conv_bias_args(
  512. {1}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 1, false, true);
  513. for (auto&& arg : args) {
  514. arg.param.format = param::ConvBias::Format::NCHW44_DOT;
  515. }
  516. checker_conv_bias_qint8x8x8(args, handle(), "ARMDOTS8_NCHW_NCHW44");
  517. }
  518. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_WITHDOTPROD) {
  519. checker_conv_bias_qint8x8x8(
  520. get_int8_quint8_conv_bias_args({2, 3, 5, 7}, 1, false, false, false),
  521. handle(), "ARMDOTS8STRD1");
  522. }
  523. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_WITHDOTPROD) {
  524. checker_conv_bias_qint8x8x8(
  525. get_int8_quint8_conv_bias_args({2, 3, 5, 7}, 2, false, false, false),
  526. handle(), "ARMDOTS8STRD2");
  527. }
  528. /****************************dot 8-8-32 direct*************************/
  529. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_I8832STRD1_WITHDOT) {
  530. checker_conv_bias_qint8x8x32(
  531. get_conv_bias_args({2, 3, 5, 7}, 1, false, true, true), handle(),
  532. "ARMDOTS8STRD1");
  533. }
  534. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_I8832STRD2_WITHDOT) {
  535. checker_conv_bias_qint8x8x32(
  536. get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(),
  537. "ARMDOTS8STRD2");
  538. }
  539. /******************************dot quint8*****************************/
  540. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE1_WITHDOTPROD) {
  541. checker_conv_bias_quint8x8x8(
  542. get_int8_quint8_conv_bias_args({2, 3, 5, 7}, 1, false, false, false),
  543. handle(), "ARMDOTU8STRD1");
  544. }
  545. //! TODO: this test without test kernel size=3, add it will case buss error now
  546. //! in armv7
  547. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE2_WITHDOTPROD) {
  548. checker_conv_bias_quint8x8x8(
  549. get_int8_quint8_conv_bias_args({2, 5, 7}, 2, false, false, false), handle(),
  550. "ARMDOTU8STRD2");
  551. }
  552. /******************************dot quint8x8x32***********************/
  553. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_QUINT8_DIRECT_STRIDE1) {
  554. checker_conv_bias_quint8x8x32(
  555. get_conv_bias_args({2, 3, 5, 7}, 1, false, true, true), handle(),
  556. "ARMDOTU8STRD1");
  557. }
  558. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_QUINT8_DIRECT_STRIDE2) {
  559. checker_conv_bias_quint8x8x32(
  560. get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(),
  561. "ARMDOTU8STRD2");
  562. }
  563. /******************************dot int8x8x8 nchw44 ***********************/
  564. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_DIRECT_DOT_NCHW44_S1_Q8x8x8) {
  565. using namespace conv_bias;
  566. std::vector<TestArg> args =
  567. get_nchw44_conv_bias_args({2, 3, 5, 7}, QUAN_NLMODE, ONLY_BR_BIASMODE, 1);
  568. for (auto&& arg : args)
  569. arg.param.format = param::ConvBias::Format::NCHW44_DOT;
  570. checker_conv_bias_qint8x8x8(args, handle(), "ARMDOTS8DIRECT_NCHW44");
  571. }
  572. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_DIRECT_DOT_NCHW44_S1_Q8x8x32) {
  573. using namespace conv_bias;
  574. std::vector<TestArg> args = get_nchw44_conv_bias_args(
  575. {2, 3, 5, 7}, ONLY_IDENTITY_NLMODE, BR_AND_NO_BIASMODE, 1);
  576. for (auto&& arg : args)
  577. arg.param.format = param::ConvBias::Format::NCHW44_DOT;
  578. checker_conv_bias_qint8x8x32(args, handle(), "ARMDOTS8DIRECT_NCHW44");
  579. }
  580. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_DIRECT_DOT_NCHW44_S1_8x8x32) {
  581. using namespace conv_bias;
  582. std::vector<TestArg> args = get_nchw44_conv_bias_args(
  583. {2, 3, 5, 7}, ONLY_IDENTITY_NLMODE, BR_AND_NO_BIASMODE, 1);
  584. for (auto&& arg : args)
  585. arg.param.format = param::ConvBias::Format::NCHW44_DOT;
  586. checker_conv_bias_int8x8x32_multi(args, handle(), "ARMDOTS8DIRECT_NCHW44");
  587. }
  588. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_DIRECT_DOT_NCHW44_S2_Q8x8x8) {
  589. using namespace conv_bias;
  590. //! test qint8x8x8
  591. std::vector<TestArg> args =
  592. get_nchw44_conv_bias_args({2, 3, 5, 7}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 2);
  593. for (auto&& arg : args)
  594. arg.param.format = param::ConvBias::Format::NCHW44_DOT;
  595. checker_conv_bias_qint8x8x8(args, handle(), "ARMDOTS8DIRECT_NCHW44");
  596. }
  597. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_DIRECT_DOT_NCHW44_S2_Q8x8x32) {
  598. using namespace conv_bias;
  599. //! test qint8x8x8
  600. std::vector<TestArg> args = get_nchw44_conv_bias_args(
  601. {2, 3, 5, 7}, ONLY_IDENTITY_NLMODE, BR_AND_NO_BIASMODE, 2);
  602. for (auto&& arg : args)
  603. arg.param.format = param::ConvBias::Format::NCHW44_DOT;
  604. checker_conv_bias_qint8x8x32(args, handle(), "ARMDOTS8DIRECT_NCHW44");
  605. }
  606. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_DIRECT_DOT_NCHW44_S2_8x8x32) {
  607. using namespace conv_bias;
  608. //! test qint8x8x8
  609. std::vector<TestArg> args = get_nchw44_conv_bias_args(
  610. {2, 3, 5, 7}, ONLY_IDENTITY_NLMODE, BR_AND_NO_BIASMODE, 2);
  611. for (auto&& arg : args)
  612. arg.param.format = param::ConvBias::Format::NCHW44_DOT;
  613. checker_conv_bias_int8x8x32_multi(args, handle(), "ARMDOTS8DIRECT_NCHW44");
  614. }
  615. #endif
  616. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD) {
  617. using namespace conv_bias;
  618. std::vector<TestArg> args = get_winograd_args(3);
  619. Checker<ConvBiasForward> checker(handle());
  620. auto run = [&checker](
  621. const std::vector<TestArg>& args, DType A_dtype, DType B_dtype,
  622. DType C_dtype, DType D_dtype, const float eps) {
  623. for (auto&& arg : args) {
  624. checker.set_dtype(0, A_dtype)
  625. .set_dtype(1, B_dtype)
  626. .set_dtype(2, C_dtype)
  627. .set_dtype(4, D_dtype)
  628. .set_epsilon(eps)
  629. .set_param(arg.param)
  630. .execs({arg.src, arg.filter, arg.bias, {}, {}});
  631. }
  632. };
  633. run(args, dtype::Float32(), dtype::Float32(), dtype::Float32(), dtype::Float32(),
  634. 1e-3f);
  635. #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  636. Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00);
  637. checker.set_rng(0, rng).set_rng(1, rng).set_rng(2, rng);
  638. run(args, dtype::Float16(), dtype::Float16(), dtype::Float16(), dtype::Float16(),
  639. 0.35f);
  640. #endif
  641. }
  642. //! uncomment it when low precision mode is ok
  643. #if 0
  644. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F73_4_NCHW44) {
  645. using namespace conv_bias;
  646. std::vector<TestArg> args =
  647. get_nchw44_conv_bias_args({3},QUAN_NLMODE,BR_AND_NO_BIASMODE,1);
  648. Checker<ConvBiasForward> checker(handle());
  649. check_winograd("4:7:16", checker, args, param::MatrixMul::Format::MK4,
  650. param::ConvBias::Format::NCHW44);
  651. }
  652. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F73_4_NCHW44_WEIGHT_PREPROCESS) {
  653. using namespace conv_bias;
  654. std::vector<TestArg> args =
  655. get_nchw44_conv_bias_args({3},QUAN_NLMODE,BR_AND_NO_BIASMODE,1);
  656. Checker<ConvBiasForward, OprWeightPreprocessProxy<ConvBiasForward>> checker(
  657. handle());
  658. check_winograd("4:7:16", checker, args, param::MatrixMul::Format::MK4,
  659. param::ConvBias::Format::NCHW44);
  660. }
  661. #endif
  662. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_MK_PACKED_F32_1) {
  663. using namespace conv_bias;
  664. Checker<ConvBiasForward> checker(handle());
  665. auto run = [&checker](
  666. const std::vector<TestArg>& args, DType A_dtype, DType B_dtype,
  667. DType C_dtype, DType D_dtype, float eps) {
  668. for (auto&& arg : args) {
  669. checker.set_dtype(0, A_dtype)
  670. .set_dtype(1, B_dtype)
  671. .set_dtype(2, C_dtype)
  672. .set_dtype(4, D_dtype)
  673. .set_epsilon(eps)
  674. .set_param(arg.param)
  675. .execs({arg.src, arg.filter, arg.bias, {}, {}});
  676. }
  677. };
  678. std::vector<TestArg> args = get_winograd_mk_packed_args(8);
  679. std::vector<TestArg> args_first_half(args.begin(), args.begin() + args.size() / 2);
  680. run(args_first_half, dtype::Float32{}, dtype::Float32{}, dtype::Float32{},
  681. dtype::Float32{}, 1e-3f);
  682. }
  683. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_MK_PACKED_F32_2) {
  684. using namespace conv_bias;
  685. Checker<ConvBiasForward> checker(handle());
  686. auto run = [&checker](
  687. const std::vector<TestArg>& args, DType A_dtype, DType B_dtype,
  688. DType C_dtype, DType D_dtype, float eps) {
  689. for (auto&& arg : args) {
  690. checker.set_dtype(0, A_dtype)
  691. .set_dtype(1, B_dtype)
  692. .set_dtype(2, C_dtype)
  693. .set_dtype(4, D_dtype)
  694. .set_epsilon(eps)
  695. .set_param(arg.param)
  696. .execs({arg.src, arg.filter, arg.bias, {}, {}});
  697. }
  698. };
  699. std::vector<TestArg> args = get_winograd_mk_packed_args(8);
  700. std::vector<TestArg> args_second_half(args.begin() + args.size() / 2, args.end());
  701. run(args_second_half, dtype::Float32{}, dtype::Float32{}, dtype::Float32{},
  702. dtype::Float32{}, 1e-3f);
  703. }
  704. #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  705. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_MK_PACKED_F16) {
  706. using namespace conv_bias;
  707. Checker<ConvBiasForward> checker(handle());
  708. auto run = [&checker](
  709. const std::vector<TestArg>& args, DType A_dtype, DType B_dtype,
  710. DType C_dtype, DType D_dtype, float eps) {
  711. for (auto&& arg : args) {
  712. checker.set_dtype(0, A_dtype)
  713. .set_dtype(1, B_dtype)
  714. .set_dtype(2, C_dtype)
  715. .set_dtype(4, D_dtype)
  716. .set_epsilon(eps)
  717. .set_param(arg.param)
  718. .execs({arg.src, arg.filter, arg.bias, {}, {}});
  719. }
  720. };
  721. std::vector<TestArg> args = get_winograd_mk_packed_args(8);
  722. Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00);
  723. checker.set_rng(0, rng).set_rng(1, rng).set_rng(2, rng);
  724. run(args, dtype::Float16{}, dtype::Float16{}, dtype::Float16{}, dtype::Float16{},
  725. 0.25);
  726. }
  727. #endif
  728. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_MK_PACKED_INT8) {
  729. using namespace conv_bias;
  730. Checker<ConvBiasForward> checker(handle());
  731. auto run = [&checker](
  732. const std::vector<TestArg>& args, DType A_dtype, DType B_dtype,
  733. DType C_dtype, DType D_dtype, float eps) {
  734. for (auto&& arg : args) {
  735. checker.set_dtype(0, A_dtype)
  736. .set_dtype(1, B_dtype)
  737. .set_dtype(2, C_dtype)
  738. .set_dtype(4, D_dtype)
  739. .set_epsilon(eps)
  740. .set_param(arg.param)
  741. .execs({arg.src, arg.filter, arg.bias, {}, {}});
  742. }
  743. };
  744. #if MEGDNN_AARCH64
  745. const char* matmul_name = "AARCH64_INT16X16X32_MK8_8X8";
  746. #else
  747. const char* matmul_name = "ARMV7_INT16X16X32_MK8_4X8";
  748. #endif
  749. checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>(
  750. ssprintf("WINOGRAD:%s:8:2:32", matmul_name).c_str()));
  751. std::vector<TestArg> quantized_args = get_quantized_winograd_mk_packed_args(8);
  752. UniformIntRNG int_rng{-50, 50};
  753. checker.set_rng(0, &int_rng).set_rng(1, &int_rng).set_rng(2, &int_rng);
  754. run(quantized_args, dtype::QuantizedS8(2.5f), dtype::QuantizedS8(2.5f),
  755. dtype::QuantizedS32(6.25f), dtype::QuantizedS8(60.25f), 1e-3);
  756. }
  757. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8) {
  758. using namespace conv_bias;
  759. Checker<ConvBiasForward> checker(handle());
  760. auto run = [&checker](
  761. const std::vector<TestArg>& args, DType A_dtype, DType B_dtype,
  762. DType C_dtype, DType D_dtype, float eps) {
  763. for (auto&& arg : args) {
  764. checker.set_dtype(0, A_dtype)
  765. .set_dtype(1, B_dtype)
  766. .set_dtype(2, C_dtype)
  767. .set_dtype(4, D_dtype)
  768. .set_epsilon(eps)
  769. .set_param(arg.param)
  770. .execs({arg.src, arg.filter, arg.bias, {}, {}});
  771. }
  772. };
  773. #if MEGDNN_AARCH64
  774. const char* matmul_name = "AARCH64_INT16X16X32_MK8_8X8";
  775. #else
  776. const char* matmul_name = "ARMV7_INT16X16X32_MK8_4X8";
  777. #endif
  778. checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>(
  779. ssprintf("WINOGRAD_NCHW44:%s:8:2:32", matmul_name).c_str()));
  780. std::vector<TestArg> quantized_args = get_int8_nchw44_args(3, 4);
  781. UniformIntRNG int_rng{-50, 50};
  782. checker.set_rng(0, &int_rng).set_rng(1, &int_rng).set_rng(2, &int_rng);
  783. run(quantized_args, dtype::QuantizedS8(2.5f), dtype::QuantizedS8(2.5f),
  784. dtype::QuantizedS32(6.25f), dtype::QuantizedS8(60.25f), 1e-3);
  785. }
  786. TEST_F(ARM_COMMON_MULTI_THREADS,
  787. CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8_COMP_F32_GROUPMODE) {
  788. using namespace conv_bias;
  789. Checker<ConvBiasForward> checker(handle());
  790. auto run = [&checker](
  791. const std::vector<TestArg>& args, DType A_dtype, DType B_dtype,
  792. DType C_dtype, DType D_dtype, float eps) {
  793. for (auto&& arg : args) {
  794. checker.set_dtype(0, A_dtype)
  795. .set_dtype(1, B_dtype)
  796. .set_dtype(2, C_dtype)
  797. .set_dtype(4, D_dtype)
  798. .set_epsilon(eps)
  799. .set_param(arg.param)
  800. .execs({arg.src, arg.filter, arg.bias, {}, {}});
  801. }
  802. };
  803. float epsilon = 0.001;
  804. #if MEGDNN_AARCH64
  805. const char* matmul_name = "AARCH64_F32_MK4_4x16";
  806. #else
  807. const char* matmul_name = "ARMV7_F32_MK4_4x8";
  808. #endif
  809. checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>(
  810. ssprintf("WINOGRAD_NCHW44:%s:4:2:32", matmul_name).c_str()));
  811. std::vector<TestArg> quantized_args = get_int8_nchw44_args(3, 4, true, true);
  812. UniformIntRNG int_rng{-50, 50};
  813. checker.set_rng(0, &int_rng).set_rng(1, &int_rng).set_rng(2, &int_rng);
  814. run(quantized_args, dtype::QuantizedS8(0.41113496f),
  815. dtype::QuantizedS8(0.01887994f), dtype::QuantizedS32(0.41113496f * 0.01887994f),
  816. dtype::QuantizedS8(0.49550694f), epsilon);
  817. }
  818. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8_GROUPMODE) {
  819. using namespace conv_bias;
  820. Checker<ConvBiasForward> checker(handle());
  821. auto run = [&checker](
  822. const std::vector<TestArg>& args, DType A_dtype, DType B_dtype,
  823. DType C_dtype, DType D_dtype, float eps) {
  824. for (auto&& arg : args) {
  825. checker.set_dtype(0, A_dtype)
  826. .set_dtype(1, B_dtype)
  827. .set_dtype(2, C_dtype)
  828. .set_dtype(4, D_dtype)
  829. .set_epsilon(eps)
  830. .set_param(arg.param)
  831. .execs({arg.src, arg.filter, arg.bias, {}, {}});
  832. }
  833. };
  834. #if MEGDNN_AARCH64
  835. const char* matmul_name = "AARCH64_INT16X16X32_MK8_8X8";
  836. #else
  837. const char* matmul_name = "ARMV7_INT16X16X32_MK8_4X8";
  838. #endif
  839. checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>(
  840. ssprintf("WINOGRAD_NCHW44:%s:8:2:32", matmul_name).c_str()));
  841. std::vector<TestArg> quantized_args = get_int8_nchw44_args(3, 4, false, true);
  842. UniformIntRNG int_rng{-50, 50};
  843. checker.set_rng(0, &int_rng).set_rng(1, &int_rng).set_rng(2, &int_rng);
  844. run(quantized_args, dtype::QuantizedS8(2.5f), dtype::QuantizedS8(2.5f),
  845. dtype::QuantizedS32(6.25f), dtype::QuantizedS8(60.25f), 1e-3);
  846. }
  847. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8_COMP_F32) {
  848. using namespace conv_bias;
  849. Checker<ConvBiasForward> checker(handle());
  850. auto run = [&checker](
  851. const std::vector<TestArg>& args, DType A_dtype, DType B_dtype,
  852. DType C_dtype, DType D_dtype, float eps) {
  853. for (auto&& arg : args) {
  854. checker.set_dtype(0, A_dtype)
  855. .set_dtype(1, B_dtype)
  856. .set_dtype(2, C_dtype)
  857. .set_dtype(4, D_dtype)
  858. .set_epsilon(eps)
  859. .set_param(arg.param)
  860. .execs({arg.src, arg.filter, arg.bias, {}, {}});
  861. }
  862. };
  863. float epsilon = 0.001;
  864. #if MEGDNN_AARCH64
  865. const char* matmul_name = "AARCH64_F32_MK4_4x16";
  866. #else
  867. const char* matmul_name = "ARMV7_F32_MK4_4x8";
  868. #endif
  869. checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>(
  870. ssprintf("WINOGRAD_NCHW44:%s:4:2:32", matmul_name).c_str()));
  871. std::vector<TestArg> quantized_args = get_int8_nchw44_args(3, 4, true);
  872. UniformIntRNG int_rng{-50, 50};
  873. checker.set_rng(0, &int_rng).set_rng(1, &int_rng).set_rng(2, &int_rng);
  874. run(quantized_args, dtype::QuantizedS8(0.41113496f),
  875. dtype::QuantizedS8(0.01887994f), dtype::QuantizedS32(0.41113496f * 0.01887994f),
  876. dtype::QuantizedS8(0.49550694f), epsilon);
  877. }
  878. #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  879. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_F23) {
  880. using namespace conv_bias;
  881. std::vector<TestArg> args = get_winograd_mk_packed_args();
  882. Checker<ConvBiasForward> checker(handle());
  883. check_winograd_fp16("1:2:32", checker, args, NULL, 0.08);
  884. }
  885. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_F45_1) {
  886. using namespace conv_bias;
  887. std::vector<TestArg> args = get_winograd_args(5);
  888. std::vector<TestArg> args_head_half(args.begin(), args.begin() + args.size() / 2);
  889. Checker<ConvBiasForward> checker(handle());
  890. //! fp16 range -1.0 ~ 1.0
  891. Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00);
  892. check_winograd_fp16("1:4:32", checker, args_head_half, rng, 0.25);
  893. }
  894. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_F45_2) {
  895. using namespace conv_bias;
  896. std::vector<TestArg> args = get_winograd_args(5);
  897. std::vector<TestArg> args_back_half(args.begin() + args.size() / 2, args.end());
  898. Checker<ConvBiasForward> checker(handle());
  899. //! fp16 range -1.0 ~ 1.0
  900. Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00);
  901. check_winograd_fp16("1:4:32", checker, args_back_half, rng, 0.25);
  902. }
  903. //! FIXME: This test may be failed if run `ARM_COMMON.CONV_BIAS_WINOGRAD*`, but
  904. //! it will pass when run single testcase
  905. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_F63) {
  906. using namespace conv_bias;
  907. std::vector<TestArg> args = get_winograd_args(3);
  908. Checker<ConvBiasForward> checker(handle());
  909. //! fp16 range -1.0 ~ 1.0
  910. Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00);
  911. check_winograd_fp16("1:6:32", checker, args, rng, 0.3);
  912. }
  913. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_8x8_1) {
  914. using namespace conv_bias;
  915. std::vector<TestArg> args = get_winograd_mk_packed_args(8);
  916. std::vector<TestArg> args_head_half(args.begin(), args.begin() + args.size() / 2);
  917. Checker<ConvBiasForward> checker(handle());
  918. Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00);
  919. check_winograd_fp16(
  920. "8:2:32", checker, args_head_half, rng, 0.25,
  921. param::MatrixMul::Format::MK8);
  922. }
  923. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_8x8_2) {
  924. using namespace conv_bias;
  925. std::vector<TestArg> args = get_winograd_mk_packed_args(8);
  926. std::vector<TestArg> args_back_half(args.begin() + args.size() / 2, args.end());
  927. Checker<ConvBiasForward> checker(handle());
  928. Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00);
  929. check_winograd_fp16(
  930. "8:2:32", checker, args_back_half, rng, 0.25,
  931. param::MatrixMul::Format::MK8);
  932. }
  933. #endif
  934. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_INT8_8X8) {
  935. using namespace conv_bias;
  936. std::vector<TestArg> args = get_quantized_winograd_mk_packed_args(8);
  937. Checker<ConvBiasForward> checker(handle());
  938. UniformIntRNG rng{-50, 50};
  939. checker.set_dtype(0, dtype::QuantizedS8(2.5f))
  940. .set_dtype(1, dtype::QuantizedS8(2.5f))
  941. .set_dtype(2, dtype::QuantizedS32(6.25f))
  942. .set_dtype(4, dtype::QuantizedS8(60.25f))
  943. .set_rng(0, &rng)
  944. .set_rng(1, &rng)
  945. .set_rng(2, &rng);
  946. check_winograd("8:2:32", checker, args, param::MatrixMul::Format::MK8);
  947. }
  948. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_INT8_8X8_WEIGHT_PREPROCESS) {
  949. using namespace conv_bias;
  950. std::vector<TestArg> args = get_quantized_winograd_mk_packed_args(8);
  951. Checker<ConvBiasForward, OprWeightPreprocessProxy<ConvBiasForward>> checker(
  952. handle());
  953. UniformIntRNG rng{-50, 50};
  954. checker.set_dtype(0, dtype::QuantizedS8(2.5f))
  955. .set_dtype(1, dtype::QuantizedS8(2.5f))
  956. .set_dtype(2, dtype::QuantizedS32(6.25f))
  957. .set_dtype(4, dtype::QuantizedS8(60.25f))
  958. .set_rng(0, &rng)
  959. .set_rng(1, &rng)
  960. .set_rng(2, &rng);
  961. check_winograd("8:2:32", checker, args, param::MatrixMul::Format::MK8);
  962. }
  963. // clang-format on
  964. // vim: syntax=cpp.doxygen