
conv_bias_multi_thread.cpp

/**
 * \file dnn/test/arm_common/conv_bias_multi_thread.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied.
 */
#include "test/arm_common/fixture.h"
#include "test/common/benchmarker.h"
#include "test/common/conv_bias.h"

using namespace megdnn;
using namespace test;
using namespace conv_bias;
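
//! Build ConvBias test cases for the int8/quint8 direct algorithms: every
//! combination of batch {1, 2}, ic/oc {1, 3, 7}, spatial size {4..18} and
//! the given kernels/stride, with and without a per-channel bias and
//! (optionally) RELU / H_SWISH nonlinearities.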
std::vector<conv_bias::TestArg> get_int8_quint8_conv_bias_args(
        std::vector<size_t> kernel, size_t stride, bool no_pad, bool no_bias,
        bool no_nonlinemode) {
    using namespace conv_bias;
    using Param = param::ConvBias;
    using NLMode = param::ConvBias::NonlineMode;
    std::vector<TestArg> args;
    auto pack = [&](size_t n, size_t oc, size_t ic, size_t w, size_t h,
                    size_t kernel, size_t stride, NLMode nlmode) {
        Param param;
        param.stride_h = stride;
        param.stride_w = stride;
        if (!no_pad) {
            param.pad_h = kernel / 2;
            param.pad_w = kernel / 2;
        } else {
            param.pad_h = 0;
            param.pad_w = 0;
        }
        param.nonlineMode = nlmode;
        args.emplace_back(param, TensorShape{n, ic, h, w},
                          TensorShape{oc, ic, kernel, kernel}, TensorShape{});
        if (!no_bias) {
            args.emplace_back(param, TensorShape{n, ic, h, w},
                              TensorShape{oc, ic, kernel, kernel},
                              TensorShape{1, oc, 1, 1});
        }
    };

    std::vector<NLMode> nonlinemode = {NLMode::IDENTITY};
    if (!no_nonlinemode) {
        nonlinemode.emplace_back(NLMode::RELU);
        nonlinemode.emplace_back(NLMode::H_SWISH);
    }

    for (size_t n : {1, 2}) {
        for (auto nlmode : nonlinemode) {
            for (size_t ic : {1, 3, 7}) {
                for (size_t oc : {1, 3, 7}) {
                    for (size_t size : {4, 6, 8, 14, 16, 18}) {
                        for (size_t kern : kernel) {
                            pack(n, oc, ic, size, size, kern, stride, nlmode);
                        }
                    }
                }
            }
        }
    }
    return args;
}
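
//! Build NCHW44-layout ConvBias test cases. `pack` skips configurations
//! that are incompatible with the 4-channel packing (oc/ic not divisible
//! per group) and, for the NCHW-input variant, grouped convolutions or
//! ic_per_group >= 4.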
std::vector<conv_bias::TestArg> get_nchw44_conv_bias_args(
        std::vector<size_t> kernel_vec, size_t stride, bool no_pad = false,
        bool no_bias = false, bool no_nonlinemode = false,
        bool is_input_nchw = false) {
    using namespace conv_bias;
    using NLMode = param::ConvBias::NonlineMode;
    std::vector<TestArg> args;
    auto pack = [&](size_t n, size_t oc, size_t ic, size_t h, size_t w,
                    size_t kernel, size_t stride, size_t group,
                    NLMode nlmode) {
        constexpr int pack_c = 4;
        const size_t pad = no_pad ? 0 : kernel / 2;
        auto bias_mode = no_bias ? megdnn::BiasMode::NO_BIAS
                                 : megdnn::BiasMode::BROADCAST_CHANNEL_BIAS;
        auto oc_per_group = oc / group;
        auto ic_per_group = ic / group;
        bool ok_group = (oc % group == 0 && ic % group == 0) &&
                        oc_per_group % pack_c == 0 && oc_per_group > 0 &&
                        ic_per_group > 0;
        bool nchw_disable = group > 1 || ic_per_group >= 4;
        bool nchw44_disable = ic_per_group % pack_c != 0;
        if (!(ok_group)) {
            return;
        }
        if ((is_input_nchw && nchw_disable) ||
            (!is_input_nchw && nchw44_disable)) {
            return;
        }

        size_t kernel_h = kernel;
        size_t kernel_w = kernel;
        param::ConvBias param;
        param.format = param::ConvBias::Format::NCHW44;
        param.stride_h = stride;
        param.stride_w = stride;
        param.pad_h = pad;
        param.pad_w = pad;
        param.nonlineMode = nlmode;
        auto src_tensor_shape = TensorShape{n, ic / pack_c, h, w, pack_c};
        auto weight_tensor_shape = TensorShape{
                oc / pack_c, ic / pack_c, kernel_h, kernel_w, pack_c, pack_c};
        auto bias_tensor_shape = TensorShape{};
        if (bias_mode == megdnn::BiasMode::BROADCAST_CHANNEL_BIAS) {
            bias_tensor_shape = {1, oc / pack_c, 1, 1, pack_c};
        }
        if (group == 1) {
            param.sparse = param::ConvBias::Sparse::DENSE;
        } else if (group > 1 && ic / group == 1 && oc / group == 1) {
            megdnn_assert(0, "not support channel wise");
            param.sparse = param::ConvBias::Sparse::GROUP;
            weight_tensor_shape = TensorShape{group / pack_c, 1, 1,
                                              kernel_h, kernel_w, pack_c};
        } else if (group > 1 && oc_per_group % pack_c == 0 && oc / group > 0 &&
                   ic_per_group % pack_c == 0 && ic / group > 0) {
            param.sparse = param::ConvBias::Sparse::GROUP;
            weight_tensor_shape = TensorShape{group,
                                              oc_per_group / pack_c,
                                              ic_per_group / pack_c,
                                              kernel_h,
                                              kernel_w,
                                              pack_c,
                                              pack_c};
        }
        if (is_input_nchw) {
            src_tensor_shape = TensorShape{n, ic, h, w};
            weight_tensor_shape =
                    TensorShape{oc / pack_c, kernel_h, kernel_w, ic, pack_c};
        }
        args.emplace_back(param, src_tensor_shape, weight_tensor_shape,
                          bias_tensor_shape);
    };

    std::vector<NLMode> nonlinemode = {NLMode::IDENTITY};
    if (!no_nonlinemode) {
        nonlinemode.emplace_back(NLMode::RELU);
        nonlinemode.emplace_back(NLMode::H_SWISH);
    }
    for (auto nlmode : nonlinemode)
        for (size_t n : {1, 2})
            for (size_t kernel : kernel_vec)
                for (size_t oc : {4, 12, 32})
                    for (size_t ic : {1, 3, 4, 12, 32})
                        for (size_t h : {3, 5, 12})
                            for (size_t w : {7, 16, 23}) {
                                for (size_t group = 1;
                                     group <= std::min(oc, ic); ++group) {
                                    pack(n, oc, ic, h, w, kernel, stride,
                                         group, nlmode);
                                }
                            }
    return args;
}
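
//! Build NCHW44 channel-wise ConvBias test cases (each group convolves a
//! single 4-channel block), covering both padded and unpadded variants.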
std::vector<conv_bias::TestArg> get_int8_quint8_nchw44_channel_wise_args(
        std::vector<size_t> kernel, size_t stride, bool no_bias,
        bool no_nonlinemode) {
    using namespace conv_bias;
    using Param = param::ConvBias;
    using NLMode = param::ConvBias::NonlineMode;
    std::vector<TestArg> args;
    auto pack = [&](size_t n, size_t group, size_t w, size_t h, size_t kernel,
                    size_t stride, NLMode nlmode, bool pad) {
        Param param;
        param.stride_h = stride;
        param.stride_w = stride;
        if (pad) {
            param.pad_h = kernel / 2;
            param.pad_w = kernel / 2;
        } else {
            param.pad_h = 0;
            param.pad_w = 0;
        }
        param.nonlineMode = nlmode;
        param.format = param::ConvBias::Format::NCHW44;
        param.sparse = param::ConvBias::Sparse::GROUP;
        args.emplace_back(param, TensorShape{n, group, h, w, 4},
                          TensorShape{group, 1, 1, kernel, kernel, 4},
                          TensorShape{});
        if (!no_bias) {
            args.emplace_back(param, TensorShape{n, group, h, w, 4},
                              TensorShape{group, 1, 1, kernel, kernel, 4},
                              TensorShape{1, group, 1, 1, 4});
        }
    };

    std::vector<NLMode> nonlinemode = {NLMode::IDENTITY};
    if (!no_nonlinemode) {
        nonlinemode.emplace_back(NLMode::RELU);
        nonlinemode.emplace_back(NLMode::H_SWISH);
    }
    for (size_t n : {1, 2}) {
        for (auto nlmode : nonlinemode) {
            for (bool pad : {true}) {
                for (size_t group : {1, 2, 4, 7, 128}) {
                    for (size_t size : {4, 5, 6, 7, 8, 9, 10, 15, 40}) {
                        for (size_t kern : kernel) {
                            pack(n, group, size, size, kern, stride, nlmode,
                                 pad);
                        }
                    }
                }
            }
            for (bool pad : {false}) {
                for (size_t group : {1, 2, 7, 128}) {
                    for (size_t size : {7, 8, 9, 10, 15, 40}) {
                        for (size_t kern : kernel) {
                            pack(n, group, size, size, kern, stride, nlmode,
                                 pad);
                        }
                    }
                }
            }
        }
    }
    return args;
}
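
//! The checker_conv_bias_* helpers below run every TestArg against the
//! algorithm selected by `algo_name`, with dtypes and RNGs matching the
//! corresponding quantization scheme.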
void checker_conv_bias_qint8x8x8(std::vector<conv_bias::TestArg> args,
                                 Handle* handle, const char* algo_name) {
    Checker<ConvBias> checker(handle);
    checker.set_before_exec_callback(
            conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name));
#if MEGDNN_ARMV7
    checker.set_epsilon(1);
#endif
    UniformIntRNG rng{-50, 50};
    checker.set_dtype(0, dtype::QuantizedS8(0.41113496f))
            .set_dtype(1, dtype::QuantizedS8(0.01887994f))
            .set_dtype(2, dtype::QuantizedS32(0.41113496f * 0.01887994f))
            .set_dtype(4, dtype::QuantizedS8(0.49550694f))
            .set_rng(0, &rng)
            .set_rng(1, &rng)
            .set_rng(2, &rng);
    for (auto&& arg : args) {
        checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}});
    }
}

void checker_conv_bias_qint8x8x32(std::vector<conv_bias::TestArg> args,
                                  Handle* handle, const char* algo_name) {
    Checker<ConvBias> checker(handle);
    UniformIntRNG rng{-50, 50};
    checker.set_dtype(0, dtype::QuantizedS8(2.5f))
            .set_dtype(1, dtype::QuantizedS8(2.5f))
            .set_dtype(2, dtype::QuantizedS32(6.25f))
            .set_dtype(4, {});
    checker.set_before_exec_callback(
            conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name));
    for (auto&& arg : args) {
        checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}});
    }
}

void checker_conv_bias_quint8x8x8(std::vector<conv_bias::TestArg> args,
                                  Handle* handle, const char* algo_name) {
    Checker<ConvBias> checker(handle);
    checker.set_before_exec_callback(
            conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name));
    UniformIntRNG rng(0, 255);
    checker.set_dtype(0, dtype::Quantized8Asymm(0.2f, 100))
            .set_dtype(1, dtype::Quantized8Asymm(0.2f, 120))
            .set_dtype(2, dtype::QuantizedS32(0.04f))
            .set_dtype(4, dtype::Quantized8Asymm(1.4f, 110))
            .set_rng(0, &rng)
            .set_rng(1, &rng)
            .set_rng(2, &rng);
    for (auto&& arg : args) {
        checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}});
    }
}

void checker_conv_bias_quint8x8x32(std::vector<conv_bias::TestArg> args,
                                   Handle* handle, const char* algo_name) {
    Checker<ConvBias> checker(handle);
    checker.set_before_exec_callback(
            conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name));
    NormalRNG rng(128.f);
    checker.set_rng(0, &rng).set_rng(1, &rng);
    checker.set_dtype(0, dtype::Quantized8Asymm(1.2f, (uint8_t)127))
            .set_dtype(1, dtype::Quantized8Asymm(1.3f, (uint8_t)129))
            .set_dtype(2, dtype::QuantizedS32(1.2 * 1.3))
            .set_dtype(4, {});
    for (auto&& arg : args) {
        checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}});
    }
}

void checker_conv_bias_int8x8x32_multi(std::vector<conv_bias::TestArg> args,
                                       Handle* handle,
                                       const char* algo_name) {
    Checker<ConvBias> checker(handle);
    checker.set_before_exec_callback(
            conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name));
    checker.set_dtype(0, dtype::Int8());
    checker.set_dtype(1, dtype::Int8());
    checker.set_dtype(2, dtype::Int32());
    checker.set_dtype(4, dtype::Int32());
    for (auto&& arg : args) {
        checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}});
    }
}
/**********************************F32 direct************************/
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_LARGE_GROUP) {
    check_conv_bias(
            get_conv_bias_args({1, 2, 3, 4, 5, 6, 7}, 1, false, false, false),
            handle(), "F32DIRECT_LARGE_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_SMALL_GROUP) {
    check_conv_bias(
            get_conv_bias_args({1, 2, 3, 4, 5, 6, 7}, 1, false, false, false),
            handle(), "F32DIRECT_SMALL_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR1_LARGE_GROUP) {
    check_conv_bias(get_conv_bias_args({2, 3, 5, 7}, 1, false, false, false),
                    handle(), "F32STRD1_LARGE_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR1_SMALL_GROUP) {
    check_conv_bias(get_conv_bias_args({2, 3, 5, 7}, 1, false, false, false),
                    handle(), "F32STRD1_SMALL_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR2_LARGE_GROUP) {
    check_conv_bias(get_conv_bias_args({2, 3, 5, 7}, 2, false, false, false),
                    handle(), "F32STRD2_LARGE_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR2_SMALL_GROUP) {
    check_conv_bias(get_conv_bias_args({2, 3, 5, 7}, 2, false, false, false),
                    handle(), "F32STRD2_SMALL_GROUP");
}

/**********************************F16 direct************************/
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16_LARGE_GROUP) {
    NormalRNG rng(1);
    checker_conv_bias_f16(
            get_conv_bias_args({1, 2, 3, 4, 5, 6, 7}, 1, false, false, false),
            handle(), rng, "F16DIRECT_LARGE_GROUP", 0.03);
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16_SMALL_GROUP) {
    NormalRNG rng(1);
    checker_conv_bias_f16(
            get_conv_bias_args({1, 2, 3, 4, 5, 6, 7}, 1, false, false, false),
            handle(), rng, "F16DIRECT_SMALL_GROUP", 0.03);
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16_STR1_LARGE_GROUP) {
    NormalRNG rng(1);
    checker_conv_bias_f16(
            get_conv_bias_args({2, 3, 5}, 1, false, false, false), handle(),
            rng, "F16STRD1_LARGE_GROUP", 0.03);
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16_STR1_SMALL_GROUP) {
    NormalRNG rng(1);
    checker_conv_bias_f16(
            get_conv_bias_args({2, 3, 5}, 1, false, false, false), handle(),
            rng, "F16STRD1_SMALL_GROUP", 0.03);
}
#endif
/**********************************algo 8816 direct************************/
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT16_DIRECT_LARGE_GROUP) {
    checker_conv_bias_int8x8x16(
            get_conv_bias_args({2, 3, 5}, 1, false, true, true), handle(),
            "I8816DIRECT_LARGE_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT16_DIRECT_SMALL_GROUP) {
    checker_conv_bias_int8x8x16(
            get_conv_bias_args({2, 3, 5}, 1, false, true, true), handle(),
            "I8816DIRECT_SMALL_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT16_STRIDE2_LARGE_GROUP) {
    checker_conv_bias_int8x8x16(
            get_conv_bias_args({2, 3, 5}, 2, false, true, true), handle(),
            "I8816STRD2_LARGE_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT16_STRIDE2_SMALL_GROUP) {
    checker_conv_bias_int8x8x16(
            get_conv_bias_args({2, 3, 5}, 2, false, true, true), handle(),
            "I8816STRD2_SMALL_GROUP");
}

/**********************************algo 8-8-32 direct************************/
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT32_STRIDE1_LARGE_GROUP) {
    checker_conv_bias_int8x8x32_multi(
            get_conv_bias_args({2, 3, 5, 7}, 1, false, true, true), handle(),
            "S8STRD1_LARGE_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT32_STRIDE1_SMALL_GROUP) {
    checker_conv_bias_int8x8x32_multi(
            get_conv_bias_args({2, 3, 5, 7}, 1, false, true, true), handle(),
            "S8STRD1_SMALL_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT32_STRIDE2_LARGE_GROUP) {
    checker_conv_bias_int8x8x32_multi(
            get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(),
            "S8STRD2_LARGE_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT32_STRIDE2_SMALL_GROUP) {
    checker_conv_bias_int8x8x32_multi(
            get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(),
            "S8STRD2_SMALL_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS,
       CONV_BIAS_INT8_INT8_INT32_CHANNEL_WISE_DIRECT1_NCHW44) {
    checker_conv_bias_int8x8x32_multi(
            get_int8_quint8_nchw44_channel_wise_args({2, 3, 5}, 1, false,
                                                     true),
            handle(), "S8_CHAN_WISE_STRD1_NCHW44");
}
TEST_F(ARM_COMMON_MULTI_THREADS,
       CONV_BIAS_INT8_INT8_INT32_CHANNEL_WISE_DIRECT2_NCHW44) {
    checker_conv_bias_int8x8x32_multi(
            get_int8_quint8_nchw44_channel_wise_args({2, 3, 5}, 2, false,
                                                     true),
            handle(), "S8_CHAN_WISE_STRD2_NCHW44");
}
/********************************qint8 direct******************************/
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_LARGE_GROUP) {
    checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args(
                                        {2, 3, 5, 7}, 1, false, false, false),
                                handle(), "S8STRD1_LARGE_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_SMALL_GROUP) {
    checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args(
                                        {2, 3, 5, 7}, 1, false, false, false),
                                handle(), "S8STRD1_SMALL_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_LARGE_GROUP) {
    checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args(
                                        {2, 3, 5, 7}, 2, false, false, false),
                                handle(), "S8STRD2_LARGE_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_SMALL_GROUP) {
    checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args(
                                        {2, 3, 5, 7}, 2, false, false, false),
                                handle(), "S8STRD2_SMALL_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_NCHW44) {
    checker_conv_bias_qint8x8x8(
            get_nchw44_conv_bias_args({2, 3, 5, 7}, 1, false, false, false),
            handle(), "S8_NCHW44_DIRECT_STRD1");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_NCHW44) {
    checker_conv_bias_qint8x8x8(
            get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, false, false),
            handle(), "S8_NCHW44_DIRECT_STRD2");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QS8_CHANNEL_WISE_DIRECT1_NCHW44) {
    checker_conv_bias_qint8x8x8(get_int8_quint8_nchw44_channel_wise_args(
                                        {2, 3, 5}, 1, false, false),
                                handle(), "S8_CHAN_WISE_STRD1_NCHW44");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QS8_CHANNEL_WISE_DIRECT2_NCHW44) {
    checker_conv_bias_qint8x8x8(get_int8_quint8_nchw44_channel_wise_args(
                                        {2, 3, 5}, 2, false, false),
                                handle(), "S8_CHAN_WISE_STRD2_NCHW44");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_NCHW_NCHW44) {
    checker_conv_bias_qint8x8x8(
            get_nchw44_conv_bias_args({3, 5, 7}, 2, false, false, false, true),
            handle(), "S8_CONV_NCHW_NCHW44");
}

/*****************************quint8 direct****************************/
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE1_LARGE_GROUP) {
    checker_conv_bias_quint8x8x8(get_int8_quint8_conv_bias_args(
                                         {2, 3, 5, 7}, 1, false, false, false),
                                 handle(), "QU8STRD1_LARGE_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE1_SMALL_GROUP) {
    checker_conv_bias_quint8x8x8(get_int8_quint8_conv_bias_args(
                                         {2, 3, 5, 7}, 1, false, false, false),
                                 handle(), "QU8STRD1_SMALL_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE2_LARGE_GROUP) {
    checker_conv_bias_quint8x8x8(get_int8_quint8_conv_bias_args(
                                         {2, 3, 5, 7}, 2, false, false, false),
                                 handle(), "QU8STRD2_LARGE_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE2_SMALL_GROUP) {
    checker_conv_bias_quint8x8x8(get_int8_quint8_conv_bias_args(
                                         {2, 3, 5, 7}, 2, false, false, false),
                                 handle(), "QU8STRD2_SMALL_GROUP");
}
/****************************dot qint8 direct*************************/
#if __ARM_FEATURE_DOTPROD
TEST_F(ARM_COMMON_MULTI_THREADS,
       CONV_BIAS_INT8_STRIDE1_WITHDOTPROD_LARGE_GROUP) {
    checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args(
                                        {2, 3, 5, 7}, 1, false, false, false),
                                handle(), "ARMDOTS8STRD1_LARGE_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS,
       CONV_BIAS_INT8_STRIDE1_WITHDOTPROD_SMALL_GROUP) {
    checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args(
                                        {2, 3, 5, 7}, 1, false, false, false),
                                handle(), "ARMDOTS8STRD1_SMALL_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS,
       CONV_BIAS_INT8_STRIDE2_WITHDOTPROD_LARGE_GROUP) {
    checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args(
                                        {2, 3, 5, 7}, 2, false, false, false),
                                handle(), "ARMDOTS8STRD2_LARGE_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS,
       CONV_BIAS_INT8_STRIDE2_WITHDOTPROD_SMALL_GROUP) {
    checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args(
                                        {2, 3, 5, 7}, 2, false, false, false),
                                handle(), "ARMDOTS8STRD2_SMALL_GROUP");
}

/****************************dot 8-8-32 direct*************************/
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_I8832STRD1_WITHDOT_LARGE_GROUP) {
    checker_conv_bias_qint8x8x32(
            get_conv_bias_args({2, 3, 5, 7}, 1, false, true, true), handle(),
            "ARMDOTS8STRD1_LARGE_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_I8832STRD1_WITHDOT_SMALL_GROUP) {
    checker_conv_bias_qint8x8x32(
            get_conv_bias_args({2, 3, 5, 7}, 1, false, true, true), handle(),
            "ARMDOTS8STRD1_SMALL_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_I8832STRD2_WITHDOT_LARGE_GROUP) {
    checker_conv_bias_qint8x8x32(
            get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(),
            "ARMDOTS8STRD2_LARGE_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_I8832STRD2_WITHDOT_SMALL_GROUP) {
    checker_conv_bias_qint8x8x32(
            get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(),
            "ARMDOTS8STRD2_SMALL_GROUP");
}

/******************************dot quint8*****************************/
TEST_F(ARM_COMMON_MULTI_THREADS,
       CONV_BIAS_QUINT8_STRIDE1_WITHDOTPROD_LARGE_GROUP) {
    checker_conv_bias_quint8x8x8(get_int8_quint8_conv_bias_args(
                                         {2, 3, 5, 7}, 1, false, false, false),
                                 handle(), "ARMDOTU8STRD1_LARGE_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS,
       CONV_BIAS_QUINT8_STRIDE1_WITHDOTPROD_SMALL_GROUP) {
    checker_conv_bias_quint8x8x8(get_int8_quint8_conv_bias_args(
                                         {2, 3, 5, 7}, 1, false, false, false),
                                 handle(), "ARMDOTU8STRD1_SMALL_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS,
       CONV_BIAS_QUINT8_STRIDE2_WITHDOTPROD_LARGE_GROUP) {
    checker_conv_bias_quint8x8x8(
            get_int8_quint8_conv_bias_args({2, 5, 7}, 2, false, false, false),
            handle(), "ARMDOTU8STRD2_LARGE_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS,
       CONV_BIAS_QUINT8_STRIDE2_WITHDOTPROD_SMALL_GROUP) {
    checker_conv_bias_quint8x8x8(
            get_int8_quint8_conv_bias_args({2, 5, 7}, 2, false, false, false),
            handle(), "ARMDOTU8STRD2_SMALL_GROUP");
}

/******************************dot quint8x8x32***********************/
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_QUINT8_DIRECT_STRIDE1_LARGE_GROUP) {
    checker_conv_bias_quint8x8x32(
            get_conv_bias_args({2, 3, 5, 7}, 1, false, true, true), handle(),
            "ARMDOTU8STRD1_LARGE_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_QUINT8_DIRECT_STRIDE1_SMALL_GROUP) {
    checker_conv_bias_quint8x8x32(
            get_conv_bias_args({2, 3, 5, 7}, 1, false, true, true), handle(),
            "ARMDOTU8STRD1_SMALL_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_QUINT8_DIRECT_STRIDE2_LARGE_GROUP) {
    checker_conv_bias_quint8x8x32(
            get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(),
            "ARMDOTU8STRD2_LARGE_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_QUINT8_DIRECT_STRIDE2_SMALL_GROUP) {
    checker_conv_bias_quint8x8x32(
            get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(),
            "ARMDOTU8STRD2_SMALL_GROUP");
}
#endif
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F23_4) {
    using namespace conv_bias;
    std::vector<TestArg> args = get_winograd_mk_packed_args();
    Checker<ConvBiasForward> checker(handle());
    check_winograd("4:2:32", checker, args, param::MatrixMul::Format::MK4);
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F63) {
    using namespace conv_bias;
    std::vector<TestArg> args = get_winograd_args(3);
    Checker<ConvBiasForward> checker(handle());
    check_winograd("1:6:32", checker, args);
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F63_4) {
    using namespace conv_bias;
    std::vector<TestArg> args = get_winograd_mk_packed_args();
    Checker<ConvBiasForward> checker(handle());
    check_winograd("4:6:32", checker, args, param::MatrixMul::Format::MK4);
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F54) {
    using namespace conv_bias;
    std::vector<TestArg> args = get_winograd_args(4);
    Checker<ConvBiasForward> checker(handle());
    check_winograd("1:5:32", checker, args);
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F45) {
    using namespace conv_bias;
    std::vector<TestArg> args = get_winograd_args(5);
    Checker<ConvBiasForward> checker(handle());
    check_winograd("1:4:32", checker, args);
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD) {
    using namespace conv_bias;
    std::vector<TestArg> args = get_winograd_args(3);
    Checker<ConvBiasForward> checker(handle());

    //! reference implementation: transform the filter explicitly with
    //! WinogradFilterPreprocess, then run ConvBias in NCHW_WINOGRAD format
    auto extra_impl = [](const TensorNDArray& tensors, uint32_t m,
                         param::ConvBias param, Handle* handle) {
        megdnn_assert(param.format == param::ConvBias::Format::NCHW);
        auto winograd_preprocess_opr =
                handle->create_operator<WinogradFilterPreprocess>();
        winograd_preprocess_opr->param().output_block_size = m;
        TensorLayout filter_transform_layout;
        winograd_preprocess_opr->deduce_layout(tensors[1].layout,
                                               filter_transform_layout);
        size_t winograd_preprocess_workspace_in_bytes =
                winograd_preprocess_opr->get_workspace_in_bytes(
                        tensors[1].layout, filter_transform_layout);

        auto conv_bias_opr = handle->create_operator<ConvBias>();
        conv_bias_opr->param() = param;
        conv_bias_opr->param().format =
                param::ConvBias::Format::NCHW_WINOGRAD;
        conv_bias_opr->param().output_block_size = m;
        size_t conv_bias_workspace_in_bytes =
                conv_bias_opr->get_workspace_in_bytes(
                        tensors[0].layout, filter_transform_layout,
                        tensors[2].layout, tensors[3].layout,
                        tensors[4].layout);

        WorkspaceBundle wb(nullptr,
                           {filter_transform_layout.span().dist_byte(),
                            conv_bias_workspace_in_bytes,
                            winograd_preprocess_workspace_in_bytes});
        wb.set(malloc(wb.total_size_in_bytes()));

        TensorND filter_transform_tensor(wb.get(0),
                                         std::move(filter_transform_layout));
        winograd_preprocess_opr->exec(tensors[1], filter_transform_tensor,
                                      wb.get_workspace(2));
        conv_bias_opr->exec(tensors[0], filter_transform_tensor, tensors[2],
                            tensors[3], tensors[4], wb.get_workspace(1));
        free(wb.ptr());
    };

    auto run = [&checker, &extra_impl](
                       Handle* handle, const std::vector<TestArg>& args,
                       const std::vector<size_t>& out_size, DType A_dtype,
                       DType B_dtype, DType C_dtype, DType D_dtype,
                       const float eps) {
        for (auto&& arg : args) {
            for (uint32_t m : out_size) {
                checker.set_extra_opr_impl(std::bind(extra_impl,
                                                     std::placeholders::_1, m,
                                                     arg.param, handle));
                checker.set_dtype(0, A_dtype)
                        .set_dtype(1, B_dtype)
                        .set_dtype(2, C_dtype)
                        .set_dtype(4, D_dtype)
                        .set_epsilon(eps)
                        .set_param(arg.param)
                        .execs({arg.src, arg.filter, arg.bias, {}, {}});
            }
        }
    };
    run(handle(), args, {6}, dtype::Float32(), dtype::Float32(),
        dtype::Float32(), dtype::Float32(), 1e-3f);
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00);
    checker.set_rng(0, rng).set_rng(1, rng).set_rng(2, rng);
    run(handle(), args, {6}, dtype::Float16(), dtype::Float16(),
        dtype::Float16(), dtype::Float16(), 0.35f);
#endif
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_MK_PACKED_F32_1) {
    using namespace conv_bias;
    Checker<ConvBiasForward> checker(handle());
    auto run = [&checker](Handle* handle, const std::vector<TestArg>& args,
                          const std::vector<size_t>& out_size, DType A_dtype,
                          DType B_dtype, DType C_dtype, DType D_dtype,
                          param::MatrixMul::Format format, float eps) {
        for (auto&& arg : args) {
            for (uint32_t m : out_size) {
                checker.set_extra_opr_impl(std::bind(
                        winograd_algo_extra_impl, std::placeholders::_1, m,
                        arg.param, handle, format));
                checker.set_dtype(0, A_dtype)
                        .set_dtype(1, B_dtype)
                        .set_dtype(2, C_dtype)
                        .set_dtype(4, D_dtype)
                        .set_epsilon(eps)
                        .set_param(arg.param)
                        .execs({arg.src, arg.filter, arg.bias, {}, {}});
            }
        }
    };
    std::vector<TestArg> args = get_winograd_mk_packed_args(8);
    std::vector<TestArg> args_first_half(args.begin(),
                                         args.begin() + args.size() / 2);
    run(handle(), args_first_half, {2, 6}, dtype::Float32{}, dtype::Float32{},
        dtype::Float32{}, dtype::Float32{}, param::MatrixMul::Format::MK4,
        1e-3f);
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_MK_PACKED_F32_2) {
    using namespace conv_bias;
    Checker<ConvBiasForward> checker(handle());
    auto run = [&checker](Handle* handle, const std::vector<TestArg>& args,
                          const std::vector<size_t>& out_size, DType A_dtype,
                          DType B_dtype, DType C_dtype, DType D_dtype,
                          param::MatrixMul::Format format, float eps) {
        for (auto&& arg : args) {
            for (uint32_t m : out_size) {
                checker.set_extra_opr_impl(std::bind(
                        winograd_algo_extra_impl, std::placeholders::_1, m,
                        arg.param, handle, format));
                checker.set_dtype(0, A_dtype)
                        .set_dtype(1, B_dtype)
                        .set_dtype(2, C_dtype)
                        .set_dtype(4, D_dtype)
                        .set_epsilon(eps)
                        .set_param(arg.param)
                        .execs({arg.src, arg.filter, arg.bias, {}, {}});
            }
        }
    };
    std::vector<TestArg> args = get_winograd_mk_packed_args(8);
    std::vector<TestArg> args_second_half(args.begin() + args.size() / 2,
                                          args.end());
    run(handle(), args_second_half, {2, 6}, dtype::Float32{}, dtype::Float32{},
        dtype::Float32{}, dtype::Float32{}, param::MatrixMul::Format::MK4,
        1e-3f);
}
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_MK_PACKED_F16) {
    using namespace conv_bias;
    Checker<ConvBiasForward> checker(handle());
    auto run = [&checker](Handle* handle, const std::vector<TestArg>& args,
                          const std::vector<size_t>& out_size, DType A_dtype,
                          DType B_dtype, DType C_dtype, DType D_dtype,
                          param::MatrixMul::Format format, float eps) {
        for (auto&& arg : args) {
            for (uint32_t m : out_size) {
                checker.set_extra_opr_impl(std::bind(
                        winograd_algo_extra_impl, std::placeholders::_1, m,
                        arg.param, handle, format));
                checker.set_dtype(0, A_dtype)
                        .set_dtype(1, B_dtype)
                        .set_dtype(2, C_dtype)
                        .set_dtype(4, D_dtype)
                        .set_epsilon(eps)
                        .set_param(arg.param)
                        .execs({arg.src, arg.filter, arg.bias, {}, {}});
            }
        }
    };
    std::vector<TestArg> args = get_winograd_mk_packed_args(8);
    Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00);
    checker.set_rng(0, rng).set_rng(1, rng).set_rng(2, rng);
    run(handle(), args, {2}, dtype::Float16{}, dtype::Float16{},
        dtype::Float16{}, dtype::Float16{}, param::MatrixMul::Format::MK8,
        0.25);
}
#endif
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_MK_PACKED_INT8) {
    using namespace conv_bias;
    Checker<ConvBiasForward> checker(handle());
    auto run = [&checker](Handle* handle, const std::vector<TestArg>& args,
                          const std::vector<size_t>& out_size, DType A_dtype,
                          DType B_dtype, DType C_dtype, DType D_dtype,
                          param::MatrixMul::Format format, float eps) {
        for (auto&& arg : args) {
            for (uint32_t m : out_size) {
                checker.set_extra_opr_impl(std::bind(
                        winograd_algo_extra_impl, std::placeholders::_1, m,
                        arg.param, handle, format));
                checker.set_dtype(0, A_dtype)
                        .set_dtype(1, B_dtype)
                        .set_dtype(2, C_dtype)
                        .set_dtype(4, D_dtype)
                        .set_epsilon(eps)
                        .set_param(arg.param)
                        .execs({arg.src, arg.filter, arg.bias, {}, {}});
            }
        }
    };

#if MEGDNN_AARCH64
    const char* matmul_name = "AARCH64_INT16X16X32_MK8_8X8";
#else
    const char* matmul_name = "ARMV7_INT16X16X32_MK8_4X8";
#endif
    checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>(
            ssprintf("WINOGRAD:%s:8:2:32", matmul_name).c_str()));

    std::vector<TestArg> args = get_winograd_mk_packed_args(8);
    std::vector<TestArg> quantized_args =
            get_quantized_winograd_mk_packed_args(8);
    UniformIntRNG int_rng{-50, 50};
    checker.set_rng(0, &int_rng).set_rng(1, &int_rng).set_rng(2, &int_rng);
    run(handle(), quantized_args, {2}, dtype::QuantizedS8(2.5f),
        dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f),
        dtype::QuantizedS8(60.25f), param::MatrixMul::Format::MK8, 1e-3);
}
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_F23) {
    using namespace conv_bias;
    std::vector<TestArg> args = get_winograd_mk_packed_args();
    Checker<ConvBiasForward> checker(handle());
    check_winograd_fp16("1:2:32", checker, args, NULL, 0.08);
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_F45_1) {
    using namespace conv_bias;
    std::vector<TestArg> args = get_winograd_args(5);
    std::vector<TestArg> args_head_half(args.begin(),
                                        args.begin() + args.size() / 2);
    Checker<ConvBiasForward> checker(handle());
    //! fp16 range -1.0 ~ 1.0
    Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00);
    check_winograd_fp16("1:4:32", checker, args_head_half, rng, 0.25);
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_F45_2) {
    using namespace conv_bias;
    std::vector<TestArg> args = get_winograd_args(5);
    std::vector<TestArg> args_back_half(args.begin() + args.size() / 2,
                                        args.end());
    Checker<ConvBiasForward> checker(handle());
    //! fp16 range -1.0 ~ 1.0
    Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00);
    check_winograd_fp16("1:4:32", checker, args_back_half, rng, 0.25);
}
//! FIXME: This test may fail when running `ARM_COMMON.CONV_BIAS_WINOGRAD*`,
//! but it passes when run as a single testcase
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_F63) {
    using namespace conv_bias;
    std::vector<TestArg> args = get_winograd_args(3);
    Checker<ConvBiasForward> checker(handle());
    //! fp16 range -1.0 ~ 1.0
    Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00);
    check_winograd_fp16("1:6:32", checker, args, rng, 0.3);
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_8x8_1) {
    using namespace conv_bias;
    std::vector<TestArg> args = get_winograd_mk_packed_args(8);
    std::vector<TestArg> args_head_half(args.begin(),
                                        args.begin() + args.size() / 2);
    Checker<ConvBiasForward> checker(handle());
    Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00);
    check_winograd_fp16("8:2:32", checker, args_head_half, rng, 0.25,
                        param::MatrixMul::Format::MK8);
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_8x8_2) {
    using namespace conv_bias;
    std::vector<TestArg> args = get_winograd_mk_packed_args(8);
    std::vector<TestArg> args_back_half(args.begin() + args.size() / 2,
                                        args.end());
    Checker<ConvBiasForward> checker(handle());
    Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00);
    check_winograd_fp16("8:2:32", checker, args_back_half, rng, 0.25,
                        param::MatrixMul::Format::MK8);
}
#endif
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_INT8_8X8) {
    using namespace conv_bias;
    std::vector<TestArg> args = get_quantized_winograd_mk_packed_args(8);
    Checker<ConvBiasForward> checker(handle());
    UniformIntRNG rng{-50, 50};
    checker.set_dtype(0, dtype::QuantizedS8(2.5f))
            .set_dtype(1, dtype::QuantizedS8(2.5f))
            .set_dtype(2, dtype::QuantizedS32(6.25f))
            .set_dtype(4, dtype::QuantizedS8(60.25f))
            .set_rng(0, &rng)
            .set_rng(1, &rng)
            .set_rng(2, &rng);
    check_winograd("8:2:32", checker, args, param::MatrixMul::Format::MK8);
}
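
//! Generic checker: run each TestArg against the algorithm `algo_name`
//! with caller-supplied dtypes, RNG and tolerance.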
void checker_conv_bias(std::vector<conv_bias::TestArg> args, Handle* handle,
                       RNG* rng, float epsilon, DType type0, DType type1,
                       DType type2, DType type3, const char* algo_name) {
    using namespace conv_bias;
    Checker<ConvBias> checker(handle);
    checker.set_before_exec_callback(
            conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name));
    checker.set_dtype(0, type0);
    checker.set_dtype(1, type1);
    checker.set_dtype(2, type2);
    checker.set_dtype(4, type3);
    checker.set_epsilon(epsilon);
    if (NULL != rng) {
        checker.set_rng(0, rng).set_rng(1, rng).set_rng(2, rng).set_rng(3,
                                                                        rng);
    }
    for (auto&& arg : args) {
        checker.set_param(arg.param).execs(
                {arg.src, arg.filter, arg.bias, {}, {}});
    }
}
// clang-format off
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COL_FP32_STRIDE2) {
#define cb(name)                                                               \
    check_conv_bias(                                                           \
            get_conv_bias_args({1, 2, 3, 4, 5, 6, 7}, 2, false, false, false), \
            handle(), name);
#if MEGDNN_AARCH64
    cb("IM2COLMATMUL:AARCH64_F32K8X12X1")
    cb("IM2COLMATMUL:AARCH64_F32K4X16X1")
    cb("IM2COLMATMUL:FB_F32_K8X12X1")
#elif MEGDNN_ARMV7
    cb("IM2COLMATMUL:ARMV7_F32")
#endif
#undef cb
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COL_FP32_STRIDE1) {
#define cb(name)                                                            \
    check_conv_bias(                                                        \
            get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, false, false), \
            handle(), name);
#if MEGDNN_AARCH64
    cb("IM2COLMATMUL:AARCH64_F32K8X12X1")
    cb("IM2COLMATMUL:AARCH64_F32K4X16X1")
    cb("IM2COLMATMUL:FB_F32_K8X12X1")
#elif MEGDNN_ARMV7
    cb("IM2COLMATMUL:ARMV7_F32")
    cb("IM2COLMATMUL:FB_F32_K8X12X1")
#endif
#undef cb
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM) {
    UniformIntRNG rng{-50, 50};
#define cb(name)                                                              \
    checker_conv_bias(get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, false, \
                                         false, true, true),                  \
                      handle(), &rng, epsilon, dtype::QuantizedS8(2.5f),      \
                      dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f),   \
                      dtype::QuantizedS8(60.25f), name);                      \
    checker_conv_bias(                                                        \
            get_conv_bias_args({1}, 2, false, false, false, true, true),      \
            handle(), &rng, epsilon, dtype::QuantizedS8(2.5f),                \
            dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f),             \
            dtype::QuantizedS8(60.25f), name);
    float epsilon = 0.001;
#if MEGDNN_AARCH64
#if __ARM_FEATURE_DOTPROD
    cb("IM2COLMATMUL:AARCH64_INT8X8X32_K8X12X4_DOTPROD");
#else
    cb("IM2COLMATMUL:AARCH64_INT8X8X32_K8X8X8");
    cb("IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16");
#endif
#elif MEGDNN_ARMV7
    epsilon = 1;
    cb("IM2COLMATMUL:ARMV7_INT8X8X32_K4X8X8");
#endif
#undef cb
}
// clang-format on

#if MEGDNN_AARCH64 || MEGDNN_ARMV7
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUANTIZEDASYM) {
    NormalRNG rng(128.f);
#define cb(name)                                                              \
    checker_conv_bias(get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, false, \
                                         false, true, true),                  \
                      handle(), &rng, epsilon,                                \
                      dtype::Quantized8Asymm(1.2f, (uint8_t)125),             \
                      dtype::Quantized8Asymm(1.3f, (uint8_t)129),             \
                      dtype::QuantizedS32(1.2 * 1.3),                         \
                      dtype::Quantized8Asymm(50.3f, (uint8_t)120), name);     \
    checker_conv_bias(                                                        \
            get_conv_bias_args({1}, 2, false, false, false, true, true),      \
            handle(), &rng, epsilon,                                          \
            dtype::Quantized8Asymm(1.2f, (uint8_t)125),                       \
            dtype::Quantized8Asymm(1.3f, (uint8_t)129),                       \
            dtype::QuantizedS32(1.2 * 1.3),                                   \
            dtype::Quantized8Asymm(50.3f, (uint8_t)120), name);
    float epsilon = 0.001;
#if MEGDNN_AARCH64
#if __ARM_FEATURE_DOTPROD
    cb("IM2COLMATMUL:AARCH64_QUINT8_K8X8X4_DOTPROD");
#else
    cb("IM2COLMATMUL:AARCH64_QUINT8_K8X8X8");
#endif
#elif MEGDNN_ARMV7
    epsilon = 1;
    cb("IM2COLMATMUL:ARMV7_QUINT8_K4X8X8");
#endif
#undef cb
}
#endif

#if MEGDNN_AARCH64 || MEGDNN_ARMV7
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUINT8x8x32) {
    UniformIntRNG rng{-50, 50};
    float epsilon = 0.001;
#define cb(name)                                                          \
    checker_conv_bias(                                                    \
            get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, true, true), \
            handle(), &rng, epsilon,                                      \
            dtype::Quantized8Asymm(1.2f, (uint8_t)125),                   \
            dtype::Quantized8Asymm(1.3f, (uint8_t)129),                   \
            dtype::QuantizedS32(1.2 * 1.3), {}, name);                    \
    checker_conv_bias(get_conv_bias_args({1}, 2, false, true, true),      \
                      handle(), &rng, epsilon,                            \
                      dtype::Quantized8Asymm(1.2f, (uint8_t)125),         \
                      dtype::Quantized8Asymm(1.3f, (uint8_t)129),         \
                      dtype::QuantizedS32(1.2 * 1.3), {}, name);
#if MEGDNN_AARCH64
#if __ARM_FEATURE_DOTPROD
    cb("IM2COLMATMUL:AARCH64_QUINT8_K8X8X4_DOTPROD");
#else
    cb("IM2COLMATMUL:AARCH64_QUINT8_K8X8X8");
#endif
#elif MEGDNN_ARMV7
#if __ARM_FEATURE_DOTPROD
    cb("IM2COLMATMUL:AARCH32_QUINT8_K4X8X4");
#endif
    cb("IM2COLMATMUL:ARMV7_QUINT8_K4X8X8");
#endif
#undef cb
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COLMATMUL_INT8x8x16) {
    UniformIntRNG rng{-50, 50};
    float epsilon = 0.001;
#define cb(name)                                                             \
    checker_conv_bias(                                                       \
            get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, true, true),    \
            handle(), &rng, epsilon, dtype::Int8{}, dtype::Int8{},           \
            dtype::Int16{}, dtype::Int16{}, name);                           \
    checker_conv_bias(get_conv_bias_args({1}, 2, false, true, true),         \
                      handle(), &rng, epsilon, dtype::Int8{}, dtype::Int8{}, \
                      dtype::Int16{}, dtype::Int16{}, name);
#if MEGDNN_AARCH64
    cb("IM2COLMATMUL:AARCH64_INT8X8X16_K8X8X8");
    cb("IM2COLMATMUL:AARCH64_INT8X8X16_K4X4X16");
    cb("IM2COLMATMUL:ARM_COMMON_INT8X8X16");
#elif MEGDNN_ARMV7
    cb("IM2COLMATMUL:ARM_COMMON_INT8X8X16");
    cb("IM2COLMATMUL:ARMV7_INT8X8X16_K4X8X8");
    cb("IM2COLMATMUL:ARMV7_INT8X8X16_K4X2X16");
#endif
#undef cb
}
#endif

#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_FP16) {
    using namespace conv_bias;
    param::ConvBias cur_param;
    std::vector<conv_bias::TestArg> args =
            get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, false, false);
    std::vector<conv_bias::TestArg> args1 =
            get_conv_bias_args({1}, 2, false, false, false);
    args.insert(args.begin(), args1.begin(), args1.end());
    NormalRNG rng(1);
#define cb(name)                                                            \
    checker_conv_bias(args, handle(), &rng, 0.03, dtype::Float16{},         \
                      dtype::Float16{}, dtype::Float16{}, dtype::Float16{}, \
                      name);
#if MEGDNN_AARCH64
    cb("IM2COLMATMUL:AARCH64_F16_K8X24X1");
#elif MEGDNN_ARMV7
    cb("IM2COLMATMUL:AARCH32_F16_K4X16X1");
#endif
#undef cb
}
#endif
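
//! Run an int8x8x32 algorithm twice over the same cases: once with plain
//! Int8/Int32 dtypes and once with QuantizedS8/QuantizedS32 dtypes.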
void checker_conv_bias_mul_int8x8x32(std::vector<conv_bias::TestArg> args,
                                     Handle* handle, const char* algo_name) {
    using namespace conv_bias;
    Checker<ConvBias> checker(handle);
    checker.set_before_exec_callback(
            conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name));
    checker.set_dtype(0, dtype::Int8());
    checker.set_dtype(1, dtype::Int8());
    checker.set_dtype(2, dtype::Int32());
    checker.set_dtype(4, dtype::Int32());
    for (auto&& arg : args) {
        checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}});
    }

    UniformIntRNG rng{-50, 50};
    for (auto&& arg : args) {
        checker.set_dtype(0, dtype::QuantizedS8(2.5f))
                .set_dtype(1, dtype::QuantizedS8(2.5f))
                .set_dtype(2, dtype::QuantizedS32(6.25f))
                .set_dtype(4, {})
                .set_rng(0, &rng)
                .set_rng(1, &rng)
                .set_rng(2, &rng)
                .set_param(arg.param)
                .execs({arg.src, arg.filter, {}, {}, {}});
    }
}
#if MEGDNN_AARCH64 || MEGDNN_ARMV7
#if !__ARM_FEATURE_DOTPROD
TEST_F(ARM_COMMON, CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44) {
    using namespace conv_bias;
    std::vector<conv_bias::TestArg> args =
            get_nchw44_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, true,
                                      true);
#define cb(name) checker_conv_bias_mul_int8x8x32(args, handle(), name);
#if MEGDNN_AARCH64
    cb("IM2COLMATMUL:AARCH64_INT8X8X32_MK4_4X4X16:96");
#else
    cb("IM2COLMATMUL:ARMV7_INT8X8X32_MK4_4X2X16:96");
#endif
#undef cb
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44_MULTI) {
    using namespace conv_bias;
    std::vector<conv_bias::TestArg> args =
            get_nchw44_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, true,
                                      true);
#define cb(name) checker_conv_bias_mul_int8x8x32(args, handle(), name);
#if MEGDNN_AARCH64
    cb("IM2COLMATMUL:AARCH64_INT8X8X32_MK4_4X4X16:96");
#else
    cb("IM2COLMATMUL:ARMV7_INT8X8X32_MK4_4X2X16:96");
#endif
#undef cb
}
TEST_F(ARM_COMMON, CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM_NCHW44) {
    UniformIntRNG rng{-50, 50};
#define cb(name)                                                            \
    checker_conv_bias(get_nchw44_conv_bias_args({2, 3, 4, 5, 6, 7}, 1),     \
                      handle(), &rng, epsilon, dtype::QuantizedS8(2.5f),    \
                      dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), \
                      dtype::QuantizedS8(60.25f), name);
    float epsilon = 0.001;
#if MEGDNN_AARCH64
    cb("IM2COLMATMUL:AARCH64_INT8X8X32_MK4_4X4X16:96");
#else
    cb("IM2COLMATMUL:ARMV7_INT8X8X32_MK4_4X2X16:96");
#endif
#undef cb
}
TEST_F(ARM_COMMON_MULTI_THREADS,
       CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM_NCHW44_MULTI) {
    UniformIntRNG rng{-50, 50};
#define cb(name)                                                            \
    checker_conv_bias(get_nchw44_conv_bias_args({2, 3, 4, 5, 6, 7}, 1),     \
                      handle(), &rng, epsilon, dtype::QuantizedS8(2.5f),    \
                      dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), \
                      dtype::QuantizedS8(60.25f), name);
    float epsilon = 0.001;
#if MEGDNN_AARCH64
    cb("IM2COLMATMUL:AARCH64_INT8X8X32_MK4_4X4X16:96");
#else
    cb("IM2COLMATMUL:ARMV7_INT8X8X32_MK4_4X2X16:96");
#endif
#undef cb
}
#if MEGDNN_AARCH64
TEST_F(ARM_COMMON_MULTI_THREADS,
       CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM_NCHW44_FUSE) {
    UniformIntRNG rng{-50, 50};
#define cb(name)                                                            \
    checker_conv_bias(get_nchw44_conv_bias_args({3}, 1), handle(), &rng,    \
                      epsilon, dtype::QuantizedS8(2.5f),                    \
                      dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), \
                      dtype::QuantizedS8(60.25f), name);
    float epsilon = 0.001;
    cb("IM2COLMATMUL:AARCH64_INT8X8X32_MK4_4X4X16:96");
#undef cb
}
#endif
#endif
#endif

TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32) {
    using namespace conv_bias;
    std::vector<conv_bias::TestArg> args =
            get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, true, true);
    std::vector<conv_bias::TestArg> args1 =
            get_conv_bias_args({1}, 2, false, true, true);
    args.insert(args.begin(), args1.begin(), args1.end());
#define cb(name) checker_conv_bias_mul_int8x8x32(args, handle(), name);
#if MEGDNN_AARCH64
#if __ARM_FEATURE_DOTPROD
    cb("IM2COLMATMUL:AARCH64_INT8X8X32_K8X12X4_DOTPROD");
#else
    cb("IM2COLMATMUL:AARCH64_INT8X8X32_K8X8X8");
    cb("IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16");
#endif
#elif MEGDNN_ARMV7
#if __ARM_FEATURE_DOTPROD
    cb("IM2COLMATMUL:AARCH32_INT8_K6X8X4");
#endif
    cb("IM2COLMATMUL:ARMV7_INT8X8X32_K4X8X8");
#endif
#if MEGDNN_ARMV7
    cb("IM2COLMATMUL:ARMV7_INT8X8X32_K4X2X16");
#endif
#undef cb
}

/***************************** Conv1x1 Algo Test ***********************/
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_F32) {
    using namespace conv_bias;
    std::vector<conv_bias::TestArg> args =
            get_conv_bias_1x1_args(false, false);
#if MEGDNN_AARCH64
    check_conv_bias(args, handle(), "CONV1x1:AARCH64_F32K8X12X1:24");
#elif MEGDNN_ARMV7
    check_conv_bias(args, handle(), "CONV1x1:ARMV7_F32:48");
#endif
}

#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_F16) {
    using namespace conv_bias;
    std::vector<conv_bias::TestArg> args =
            get_conv_bias_1x1_args(false, false);
    NormalRNG rng(1);
#if MEGDNN_AARCH64
    checker_conv_bias(args, handle(), &rng, 0.03, dtype::Float16{},
                      dtype::Float16{}, dtype::Float16{}, dtype::Float16{},
                      "CONV1x1:AARCH64_F16_K8X24X1:48");
#elif MEGDNN_ARMV7
    checker_conv_bias(args, handle(), &rng, 0.03, dtype::Float16{},
                      dtype::Float16{}, dtype::Float16{}, dtype::Float16{},
                      "CONV1x1:AARCH32_F16_K4X16X1:24");
#endif
}
#endif

TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_QUANTIZEDSYM) {
    UniformIntRNG rng{-50, 50};
    float epsilon = 0.001;
#define cb(name)                                                            \
    checker_conv_bias(get_conv_bias_1x1_args(false, false, true, true),     \
                      handle(), &rng, epsilon, dtype::QuantizedS8(2.5f),    \
                      dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), \
                      dtype::QuantizedS8(60.25f), name);
#if MEGDNN_AARCH64
#if __ARM_FEATURE_DOTPROD
    cb("CONV1x1:AARCH64_INT8X8X32_K8X12X4_DOTPROD:24");
#else
    cb("CONV1x1:AARCH64_INT8X8X32_K8X8X8:24");
    cb("CONV1x1:AARCH64_INT8X8X32_K4X4X16:48");
#endif
#elif MEGDNN_ARMV7
    epsilon = 1;
    cb("CONV1x1:ARMV7_INT8X8X32_K4X8X8:48");
#endif
#undef cb
}

#if MEGDNN_AARCH64 || MEGDNN_ARMV7
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_QUANTIZEDASYM) {
    NormalRNG rng(128.f);
#define cb(name)                                                          \
    checker_conv_bias(get_conv_bias_1x1_args(false, false, true, true),   \
                      handle(), &rng, epsilon,                            \
                      dtype::Quantized8Asymm(1.2f, (uint8_t)125),         \
                      dtype::Quantized8Asymm(1.3f, (uint8_t)129),         \
                      dtype::QuantizedS32(1.2 * 1.3),                     \
                      dtype::Quantized8Asymm(50.3f, (uint8_t)120), name);
    float epsilon = 0.001;
#if MEGDNN_AARCH64
#if __ARM_FEATURE_DOTPROD
    cb("CONV1x1:AARCH64_QUINT8_K8X8X4_DOTPROD:48");
#else
    cb("CONV1x1:AARCH64_QUINT8_K8X8X8:24");
#endif
#elif MEGDNN_ARMV7
    epsilon = 1;
    cb("CONV1x1:ARMV7_QUINT8_K4X8X8:48");
#endif
#undef cb
}
#endif

#if MEGDNN_AARCH64 || MEGDNN_ARMV7
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_QUINT8x8x32) {
    UniformIntRNG rng{-50, 50};
    float epsilon = 0.001;
#define cb(name)                                                           \
    checker_conv_bias(get_conv_bias_1x1_args(true, true), handle(), &rng,  \
                      epsilon, dtype::Quantized8Asymm(1.2f, (uint8_t)125), \
                      dtype::Quantized8Asymm(1.3f, (uint8_t)129),          \
                      dtype::QuantizedS32(1.2 * 1.3), {}, name);
#if MEGDNN_AARCH64
#if __ARM_FEATURE_DOTPROD
    cb("CONV1x1:AARCH64_QUINT8_K8X8X4_DOTPROD:24");
#else
    cb("CONV1x1:AARCH64_QUINT8_K8X8X8:48");
#endif
#elif MEGDNN_ARMV7
#if __ARM_FEATURE_DOTPROD
    cb("CONV1x1:AARCH32_QUINT8_K4X8X4:48");
#endif
    cb("CONV1x1:ARMV7_QUINT8_K4X8X8:24");
#endif
#undef cb
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_1X1_S1_INT8x8x16) {
    UniformIntRNG rng{-50, 50};
    float epsilon = 0.001;
#define cb(name)                                                             \
    checker_conv_bias(get_conv_bias_1x1_args(true, true), handle(), &rng,    \
                      epsilon, dtype::Int8{}, dtype::Int8{}, dtype::Int16{}, \
                      dtype::Int16{}, name);
#if MEGDNN_AARCH64
    cb("CONV1x1:AARCH64_INT8X8X16_K8X8X8:24");
    cb("CONV1x1:AARCH64_INT8X8X16_K4X4X16:24");
#elif MEGDNN_ARMV7
    cb("CONV1x1:ARMV7_INT8X8X16_K4X8X8:24");
    cb("CONV1x1:ARMV7_INT8X8X16_K4X2X16:48");
#endif
    cb("CONV1x1:ARM_COMMON_INT8X8X16:48");
#undef cb
}
#endif

TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_INT8x8x32) {
    using namespace conv_bias;
    std::vector<conv_bias::TestArg> args = get_conv_bias_1x1_args(true, true);
#define cb(name) checker_conv_bias_mul_int8x8x32(args, handle(), name);
#if MEGDNN_AARCH64
#if __ARM_FEATURE_DOTPROD
    cb("CONV1x1:AARCH64_INT8X8X32_K8X12X4_DOTPROD:48");
#else
    cb("CONV1x1:AARCH64_INT8X8X32_K8X8X8:24");
    cb("CONV1x1:AARCH64_INT8X8X32_K4X4X16:24");
#endif
#elif MEGDNN_ARMV7
#if __ARM_FEATURE_DOTPROD
    cb("CONV1x1:AARCH32_INT8_K6X8X4:48");
#endif
    cb("CONV1x1:ARMV7_INT8X8X32_K4X8X8:24");
#endif
#if MEGDNN_ARMV7
    cb("CONV1x1:ARMV7_INT8X8X32_K4X2X16:48");
#endif
#undef cb
}

#ifndef __ARM_FEATURE_DOTPROD
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_INT8x8x32_MK4) {
    using namespace conv_bias;
    std::vector<conv_bias::TestArg> args =
            get_nchw44_conv_bias_args({1}, 1, true, true, true);
#define cb(name) checker_conv_bias_mul_int8x8x32(args, handle(), name);
#if MEGDNN_AARCH64
    cb("CONV1x1:AARCH64_INT8X8X32_MK4_4X4X16:24");
#elif MEGDNN_ARMV7
    cb("CONV1x1:ARMV7_INT8X8X32_MK4_4X2X16:24");
#endif
#undef cb

    UniformIntRNG rng{-50, 50};
    float epsilon = 0.001;
#define cb(name)                                                             \
    checker_conv_bias(get_nchw44_conv_bias_args({1}, 1, true, false, false), \
                      handle(), &rng, epsilon, dtype::QuantizedS8(2.5f),     \
                      dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f),  \
                      dtype::QuantizedS8(60.25f), name);
#if MEGDNN_AARCH64
    cb("CONV1x1:AARCH64_INT8X8X32_MK4_4X4X16:24");
#elif MEGDNN_ARMV7
    cb("CONV1x1:ARMV7_INT8X8X32_MK4_4X2X16:24");
#endif
#undef cb
}
#endif

// vim: syntax=cpp.doxygen

The MegEngine installation package bundles the CUDA environment needed to run code on GPU, so there is no separate CPU/GPU build. To run GPU programs, make sure the machine has a GPU device and the driver is installed. If you would like to try deep-learning development on a cloud GPU compute platform, you are welcome to visit the MegStudio platform.