You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

conv_bias_multi_thread.cpp 59 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422
  1. /**
  2. * \file dnn/test/arm_common/conv_bias_multi_thread.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #include "test/arm_common/fixture.h"
  13. #include "test/common/benchmarker.h"
  14. #include "test/common/conv_bias.h"
  15. using namespace megdnn;
  16. using namespace test;
  17. using namespace conv_bias;
  18. std::vector<conv_bias::TestArg> get_int8_quint8_conv_bias_args(
  19. std::vector<size_t> kernel, size_t stride, bool no_pad, bool no_bias,
  20. bool no_nonlinemode) {
  21. using namespace conv_bias;
  22. using Param = param::ConvBias;
  23. using NLMode = param::ConvBias::NonlineMode;
  24. std::vector<TestArg> args;
  25. auto pack = [&](size_t n, size_t oc, size_t ic, size_t w, size_t h,
  26. size_t kernel, size_t stride, NLMode nlmode) {
  27. Param param;
  28. param.stride_h = stride;
  29. param.stride_w = stride;
  30. if (!no_pad) {
  31. param.pad_h = kernel / 2;
  32. param.pad_w = kernel / 2;
  33. } else {
  34. param.pad_h = 0;
  35. param.pad_w = 0;
  36. }
  37. param.nonlineMode = nlmode;
  38. args.emplace_back(param, TensorShape{n, ic, h, w},
  39. TensorShape{oc, ic, kernel, kernel}, TensorShape{});
  40. if (!no_bias) {
  41. args.emplace_back(param, TensorShape{n, ic, h, w},
  42. TensorShape{oc, ic, kernel, kernel},
  43. TensorShape{1, oc, 1, 1});
  44. }
  45. };
  46. std::vector<NLMode> nonlinemode = {NLMode::IDENTITY};
  47. if (!no_nonlinemode) {
  48. nonlinemode.emplace_back(NLMode::RELU);
  49. nonlinemode.emplace_back(NLMode::H_SWISH);
  50. }
  51. for (size_t n : {1, 2}) {
  52. for (auto nlmode : nonlinemode) {
  53. for (size_t ic : {1, 3, 7}) {
  54. for (size_t oc : {1, 3, 7}) {
  55. for (size_t size : {4, 6, 8, 14, 16, 18}) {
  56. for (size_t kern : kernel) {
  57. pack(n, oc, ic, size, size, kern, stride, nlmode);
  58. }
  59. }
  60. }
  61. }
  62. }
  63. }
  64. return args;
  65. }
  66. std::vector<conv_bias::TestArg> get_nchw44_conv_bias_args(
  67. std::vector<size_t> kernel_vec, size_t stride, bool no_pad = false,
  68. bool no_bias = false, bool no_nonlinemode = false,
  69. bool is_input_nchw = false) {
  70. using namespace conv_bias;
  71. using NLMode = param::ConvBias::NonlineMode;
  72. std::vector<TestArg> args;
  73. auto pack = [&](size_t n, size_t oc, size_t ic, size_t h, size_t w,
  74. size_t kernel, size_t stride, size_t group, NLMode nlmode) {
  75. constexpr int pack_c = 4;
  76. const size_t pad = no_pad ? 0 : kernel / 2;
  77. auto bias_mode = no_bias ? megdnn::BiasMode::NO_BIAS
  78. : megdnn::BiasMode::BROADCAST_CHANNEL_BIAS;
  79. auto oc_per_group = oc / group;
  80. auto ic_per_group = ic / group;
  81. bool ok_group = (oc % group == 0 && ic % group == 0) &&
  82. oc_per_group % pack_c == 0 && oc_per_group > 0 &&
  83. ic_per_group > 0;
  84. bool nchw_disable = group > 1 || ic_per_group >= 4;
  85. bool nchw44_disable = ic_per_group % pack_c != 0;
  86. if (!(ok_group)) {
  87. return;
  88. }
  89. if ((is_input_nchw && nchw_disable) ||
  90. (!is_input_nchw && nchw44_disable)) {
  91. return;
  92. }
  93. size_t kernel_h = kernel;
  94. size_t kernel_w = kernel;
  95. param::ConvBias param;
  96. param.format = param::ConvBias::Format::NCHW44;
  97. param.stride_h = stride;
  98. param.stride_w = stride;
  99. param.pad_h = pad;
  100. param.pad_w = pad;
  101. param.nonlineMode = nlmode;
  102. auto src_tensor_shape = TensorShape{n, ic / pack_c, h, w, pack_c};
  103. auto weight_tensor_shape = TensorShape{
  104. oc / pack_c, ic / pack_c, kernel_h, kernel_w, pack_c, pack_c};
  105. auto bias_tensor_shape = TensorShape{};
  106. if (bias_mode == megdnn::BiasMode::BROADCAST_CHANNEL_BIAS) {
  107. bias_tensor_shape = {1, oc / pack_c, 1, 1, pack_c};
  108. }
  109. if (group == 1) {
  110. param.sparse = param::ConvBias::Sparse::DENSE;
  111. } else if (group > 1 && ic / group == 1 && oc / group == 1) {
  112. megdnn_assert(0, "not support channel wise");
  113. param.sparse = param::ConvBias::Sparse::GROUP;
  114. weight_tensor_shape = TensorShape{group / pack_c, 1, 1,
  115. kernel_h, kernel_w, pack_c};
  116. } else if (group > 1 && oc_per_group % pack_c == 0 && oc / group > 0 &&
  117. ic_per_group % pack_c == 0 && ic / group > 0) {
  118. param.sparse = param::ConvBias::Sparse::GROUP;
  119. weight_tensor_shape = TensorShape{group,
  120. oc_per_group / pack_c,
  121. ic_per_group / pack_c,
  122. kernel_h,
  123. kernel_w,
  124. pack_c,
  125. pack_c};
  126. }
  127. if (is_input_nchw) {
  128. src_tensor_shape = TensorShape{n, ic, h, w};
  129. weight_tensor_shape =
  130. TensorShape{oc / pack_c, kernel_h, kernel_w, ic, pack_c};
  131. }
  132. args.emplace_back(param, src_tensor_shape, weight_tensor_shape,
  133. bias_tensor_shape);
  134. };
  135. std::vector<NLMode> nonlinemode = {NLMode::IDENTITY};
  136. if (!no_nonlinemode) {
  137. nonlinemode.emplace_back(NLMode::RELU);
  138. nonlinemode.emplace_back(NLMode::H_SWISH);
  139. }
  140. for (auto nlmode : nonlinemode)
  141. for (size_t n : {1, 2})
  142. for (size_t kernel : kernel_vec)
  143. for (size_t oc : {4, 12, 32})
  144. for (size_t ic : {1, 3, 4, 12, 32})
  145. for (size_t h : {3, 5, 12})
  146. for (size_t w : {7, 16, 23}) {
  147. for (size_t group = 1;
  148. group <= std::min(oc, ic); ++group) {
  149. pack(n, oc, ic, h, w, kernel, stride, group,
  150. nlmode);
  151. }
  152. }
  153. return args;
  154. }
  155. std::vector<conv_bias::TestArg> get_int8_quint8_nchw44_channel_wise_args(
  156. std::vector<size_t> kernel, size_t stride, bool no_bias,
  157. bool no_nonlinemode) {
  158. using namespace conv_bias;
  159. using Param = param::ConvBias;
  160. using NLMode = param::ConvBias::NonlineMode;
  161. std::vector<TestArg> args;
  162. auto pack = [&](size_t n, size_t group, size_t w, size_t h, size_t kernel,
  163. size_t stride, NLMode nlmode, bool pad) {
  164. Param param;
  165. param.stride_h = stride;
  166. param.stride_w = stride;
  167. if (pad) {
  168. param.pad_h = kernel / 2;
  169. param.pad_w = kernel / 2;
  170. } else {
  171. param.pad_h = 0;
  172. param.pad_w = 0;
  173. }
  174. param.nonlineMode = nlmode;
  175. param.format = param::ConvBias::Format::NCHW44;
  176. param.sparse = param::ConvBias::Sparse::GROUP;
  177. args.emplace_back(param, TensorShape{n, group, h, w, 4},
  178. TensorShape{group, 1, 1, kernel, kernel, 4},
  179. TensorShape{});
  180. if (!no_bias) {
  181. args.emplace_back(param, TensorShape{n, group, h, w, 4},
  182. TensorShape{group, 1, 1, kernel, kernel, 4},
  183. TensorShape{1, group, 1, 1, 4});
  184. }
  185. };
  186. std::vector<NLMode> nonlinemode = {NLMode::IDENTITY};
  187. if (!no_nonlinemode) {
  188. nonlinemode.emplace_back(NLMode::RELU);
  189. nonlinemode.emplace_back(NLMode::H_SWISH);
  190. }
  191. for (size_t n : {1, 2}) {
  192. for (auto nlmode : nonlinemode) {
  193. for (bool pad : {true}) {
  194. for (size_t group : {1, 2, 4, 7, 128}) {
  195. for (size_t size : {4, 5, 6, 7, 8, 9, 10, 15, 40}) {
  196. for (size_t kern : kernel) {
  197. pack(n, group, size, size, kern, stride, nlmode,
  198. pad);
  199. }
  200. }
  201. }
  202. }
  203. for (bool pad : {false}) {
  204. for (size_t group : {1, 2, 7, 128}) {
  205. for (size_t size : {7, 8, 9, 10, 15, 40}) {
  206. for (size_t kern : kernel) {
  207. pack(n, group, size, size, kern, stride, nlmode,
  208. pad);
  209. }
  210. }
  211. }
  212. }
  213. }
  214. }
  215. return args;
  216. }
  217. void checker_conv_bias_qint8x8x8(std::vector<conv_bias::TestArg> args,
  218. Handle* handle, const char* algo_name) {
  219. Checker<ConvBias> checker(handle);
  220. checker.set_before_exec_callback(
  221. conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name));
  222. #if MEGDNN_ARMV7
  223. checker.set_epsilon(1);
  224. #endif
  225. UniformIntRNG rng{-50, 50};
  226. checker.set_dtype(0, dtype::QuantizedS8(0.41113496f))
  227. .set_dtype(1, dtype::QuantizedS8(0.01887994f))
  228. .set_dtype(2, dtype::QuantizedS32(0.41113496f * 0.01887994f))
  229. .set_dtype(4, dtype::QuantizedS8(0.49550694f))
  230. .set_rng(0, &rng)
  231. .set_rng(1, &rng)
  232. .set_rng(2, &rng);
  233. for (auto&& arg : args) {
  234. checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}});
  235. }
  236. }
  237. void checker_conv_bias_qint8x8x32(std::vector<conv_bias::TestArg> args,
  238. Handle* handle, const char* algo_name) {
  239. Checker<ConvBias> checker(handle);
  240. UniformIntRNG rng{-50, 50};
  241. checker.set_dtype(0, dtype::QuantizedS8(2.5f))
  242. .set_dtype(1, dtype::QuantizedS8(2.5f))
  243. .set_dtype(2, dtype::QuantizedS32(6.25f))
  244. .set_dtype(4, {});
  245. checker.set_before_exec_callback(
  246. conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name));
  247. for (auto&& arg : args) {
  248. checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}});
  249. }
  250. }
  251. void checker_conv_bias_quint8x8x8(std::vector<conv_bias::TestArg> args,
  252. Handle* handle, const char* algo_name) {
  253. Checker<ConvBias> checker(handle);
  254. checker.set_before_exec_callback(
  255. conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name));
  256. UniformIntRNG rng(0, 255);
  257. checker.set_dtype(0, dtype::Quantized8Asymm(0.2f, 100))
  258. .set_dtype(1, dtype::Quantized8Asymm(0.2f, 120))
  259. .set_dtype(2, dtype::QuantizedS32(0.04f))
  260. .set_dtype(4, dtype::Quantized8Asymm(1.4f, 110))
  261. .set_rng(0, &rng)
  262. .set_rng(1, &rng)
  263. .set_rng(2, &rng);
  264. for (auto&& arg : args) {
  265. checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}});
  266. }
  267. }
  268. void checker_conv_bias_quint8x8x32(std::vector<conv_bias::TestArg> args,
  269. Handle* handle, const char* algo_name) {
  270. Checker<ConvBias> checker(handle);
  271. checker.set_before_exec_callback(
  272. conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name));
  273. NormalRNG rng(128.f);
  274. checker.set_rng(0, &rng).set_rng(1, &rng);
  275. checker.set_dtype(0, dtype::Quantized8Asymm(1.2f, (uint8_t)127))
  276. .set_dtype(1, dtype::Quantized8Asymm(1.3f, (uint8_t)129))
  277. .set_dtype(2, dtype::QuantizedS32(1.2 * 1.3))
  278. .set_dtype(4, {});
  279. for (auto&& arg : args) {
  280. checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}});
  281. }
  282. }
  283. void checker_conv_bias_int8x8x32_multi(std::vector<conv_bias::TestArg> args,
  284. Handle* handle, const char* algo_name) {
  285. Checker<ConvBias> checker(handle);
  286. checker.set_before_exec_callback(
  287. conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name));
  288. checker.set_dtype(0, dtype::Int8());
  289. checker.set_dtype(1, dtype::Int8());
  290. checker.set_dtype(2, dtype::Int32());
  291. checker.set_dtype(4, dtype::Int32());
  292. for (auto&& arg : args) {
  293. checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}});
  294. }
  295. }
  296. /**********************************F32 direct************************/
  297. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_LARGE_GROUP) {
  298. check_conv_bias(
  299. get_conv_bias_args({1, 2, 3, 4, 5, 6, 7}, 1, false, false, false),
  300. handle(), "F32DIRECT_LARGE_GROUP");
  301. }
  302. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_SMALL_GROUP) {
  303. check_conv_bias(
  304. get_conv_bias_args({1, 2, 3, 4, 5, 6, 7}, 1, false, false, false),
  305. handle(), "F32DIRECT_SMALL_GROUP");
  306. }
  307. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR1_LARGE_GROUP) {
  308. check_conv_bias(get_conv_bias_args({2, 3, 5, 7}, 1, false, false, false),
  309. handle(), "F32STRD1_LARGE_GROUP");
  310. }
  311. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR1_SMALL_GROUP) {
  312. check_conv_bias(get_conv_bias_args({2, 3, 5, 7}, 1, false, false, false),
  313. handle(), "F32STRD1_SMALL_GROUP");
  314. }
  315. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR2_LARGE_GROUP) {
  316. check_conv_bias(get_conv_bias_args({2, 3, 5, 7}, 2, false, false, false),
  317. handle(), "F32STRD2_LARGE_GROUP");
  318. }
  319. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR2_SMALL_GROUP) {
  320. check_conv_bias(get_conv_bias_args({2, 3, 5, 7}, 2, false, false, false),
  321. handle(), "F32STRD2_SMALL_GROUP");
  322. }
  323. /**********************************F16 direct************************/
  324. #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  325. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16_LARGE_GROUP) {
  326. NormalRNG rng(1);
  327. checker_conv_bias_f16(
  328. get_conv_bias_args({1, 2, 3, 4, 5, 6, 7}, 1, false, false, false),
  329. handle(), rng, "F16DIRECT_LARGE_GROUP", 0.03);
  330. }
  331. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16_SMALL_GROUP) {
  332. NormalRNG rng(1);
  333. checker_conv_bias_f16(
  334. get_conv_bias_args({1, 2, 3, 4, 5, 6, 7}, 1, false, false, false),
  335. handle(), rng, "F16DIRECT_SMALL_GROUP", 0.03);
  336. }
  337. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16_STR1_LARGE_GROUP) {
  338. NormalRNG rng(1);
  339. checker_conv_bias_f16(get_conv_bias_args({2, 3, 5}, 1, false, false, false),
  340. handle(), rng, "F16STRD1_LARGE_GROUP", 0.03);
  341. }
  342. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16_STR1_SMALL_GROUP) {
  343. NormalRNG rng(1);
  344. checker_conv_bias_f16(get_conv_bias_args({2, 3, 5}, 1, false, false, false),
  345. handle(), rng, "F16STRD1_SMALL_GROUP", 0.03);
  346. }
  347. #endif
  348. /**********************************algo 8816 direct************************/
  349. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT16_DIRECT_LARGE_GROUP) {
  350. checker_conv_bias_int8x8x16(
  351. get_conv_bias_args({2, 3, 5}, 1, false, true, true), handle(),
  352. "I8816DIRECT_LARGE_GROUP");
  353. }
  354. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT16_DIRECT_SMALL_GROUP) {
  355. checker_conv_bias_int8x8x16(
  356. get_conv_bias_args({2, 3, 5}, 1, false, true, true), handle(),
  357. "I8816DIRECT_SMALL_GROUP");
  358. }
  359. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT16_STRIDE2_LARGE_GROUP) {
  360. checker_conv_bias_int8x8x16(
  361. get_conv_bias_args({2, 3, 5}, 2, false, true, true), handle(),
  362. "I8816STRD2_LARGE_GROUP");
  363. }
  364. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT16_STRIDE2_SMALL_GROUP) {
  365. checker_conv_bias_int8x8x16(
  366. get_conv_bias_args({2, 3, 5}, 2, false, true, true), handle(),
  367. "I8816STRD2_SMALL_GROUP");
  368. }
  369. /**********************************algo 8-8-32 direct************************/
  370. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT32_STRIDE1_LARGE_GROUP) {
  371. checker_conv_bias_int8x8x32_multi(
  372. get_conv_bias_args({2, 3, 5, 7}, 1, false, true, true), handle(),
  373. "S8STRD1_LARGE_GROUP");
  374. }
  375. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT32_STRIDE1_SMALL_GROUP) {
  376. checker_conv_bias_int8x8x32_multi(
  377. get_conv_bias_args({2, 3, 5, 7}, 1, false, true, true), handle(),
  378. "S8STRD1_SMALL_GROUP");
  379. }
  380. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT32_STRIDE2_LARGE_GROUP) {
  381. checker_conv_bias_int8x8x32_multi(
  382. get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(),
  383. "S8STRD2_LARGE_GROUP");
  384. }
  385. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT32_STRIDE2_SMALL_GROUP) {
  386. checker_conv_bias_int8x8x32_multi(
  387. get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(),
  388. "S8STRD2_SMALL_GROUP");
  389. }
  390. TEST_F(ARM_COMMON_MULTI_THREADS,
  391. CONV_BIAS_INT8_INT8_INT32_CHANNEL_WISE_DIRECT1_NCHW44) {
  392. checker_conv_bias_int8x8x32_multi(
  393. get_int8_quint8_nchw44_channel_wise_args({2, 3, 5}, 1, false, true),
  394. handle(), "S8_CHAN_WISE_STRD1_NCHW44");
  395. }
  396. TEST_F(ARM_COMMON_MULTI_THREADS,
  397. CONV_BIAS_INT8_INT8_INT32_CHANNEL_WISE_DIRECT2_NCHW44) {
  398. checker_conv_bias_int8x8x32_multi(
  399. get_int8_quint8_nchw44_channel_wise_args({2, 3, 5}, 2, false, true),
  400. handle(), "S8_CHAN_WISE_STRD2_NCHW44");
  401. }
  402. /********************************qint8 direct******************************/
  403. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_LARGE_GROUP) {
  404. checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args(
  405. {2, 3, 5, 7}, 1, false, false, false),
  406. handle(), "S8STRD1_LARGE_GROUP");
  407. }
  408. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_SMALL_GROUP) {
  409. checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args(
  410. {2, 3, 5, 7}, 1, false, false, false),
  411. handle(), "S8STRD1_SMALL_GROUP");
  412. }
  413. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_LARGE_GROUP) {
  414. checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args(
  415. {2, 3, 5, 7}, 2, false, false, false),
  416. handle(), "S8STRD2_LARGE_GROUP");
  417. }
  418. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_SMALL_GROUP) {
  419. checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args(
  420. {2, 3, 5, 7}, 2, false, false, false),
  421. handle(), "S8STRD2_SMALL_GROUP");
  422. }
  423. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_NCHW44) {
  424. checker_conv_bias_qint8x8x8(
  425. get_nchw44_conv_bias_args({2, 3, 5, 7}, 1, false, false, false),
  426. handle(), "S8_NCHW44_DIRECT_STRD1");
  427. }
  428. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_NCHW44) {
  429. checker_conv_bias_qint8x8x8(
  430. get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, false, false),
  431. handle(), "S8_NCHW44_DIRECT_STRD2");
  432. }
  433. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QS8_CHANNEL_WISE_DIRECT1_NCHW44) {
  434. checker_conv_bias_qint8x8x8(get_int8_quint8_nchw44_channel_wise_args(
  435. {2, 3, 5}, 1, false, false),
  436. handle(), "S8_CHAN_WISE_STRD1_NCHW44");
  437. }
  438. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QS8_CHANNEL_WISE_DIRECT2_NCHW44) {
  439. checker_conv_bias_qint8x8x8(get_int8_quint8_nchw44_channel_wise_args(
  440. {2, 3, 5}, 2, false, false),
  441. handle(), "S8_CHAN_WISE_STRD2_NCHW44");
  442. }
  443. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_NCHW_NCHW44) {
  444. checker_conv_bias_qint8x8x8(
  445. get_nchw44_conv_bias_args({3, 5, 7}, 2, false, false, false, true),
  446. handle(), "S8_CONV_NCHW_NCHW44");
  447. }
  448. /*****************************quint8 direct****************************/
  449. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE1_LARGE_GROUP) {
  450. checker_conv_bias_quint8x8x8(get_int8_quint8_conv_bias_args(
  451. {2, 3, 5, 7}, 1, false, false, false),
  452. handle(), "QU8STRD1_LARGE_GROUP");
  453. }
  454. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE1_SMALL_GROUP) {
  455. checker_conv_bias_quint8x8x8(get_int8_quint8_conv_bias_args(
  456. {2, 3, 5, 7}, 1, false, false, false),
  457. handle(), "QU8STRD1_SMALL_GROUP");
  458. }
  459. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE2_LARGE_GROUP) {
  460. checker_conv_bias_quint8x8x8(get_int8_quint8_conv_bias_args(
  461. {2, 3, 5, 7}, 2, false, false, false),
  462. handle(), "QU8STRD2_LARGE_GROUP");
  463. }
  464. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE2_SMALL_GROUP) {
  465. checker_conv_bias_quint8x8x8(get_int8_quint8_conv_bias_args(
  466. {2, 3, 5, 7}, 2, false, false, false),
  467. handle(), "QU8STRD2_SMALL_GROUP");
  468. }
  469. /****************************dot qint8 direct*************************/
  470. #if __ARM_FEATURE_DOTPROD
  471. TEST_F(ARM_COMMON_MULTI_THREADS,
  472. CONV_BIAS_INT8_STRIDE1_WITHDOTPROD_LARGE_GROUP) {
  473. checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args(
  474. {2, 3, 5, 7}, 1, false, false, false),
  475. handle(), "ARMDOTS8STRD1_LARGE_GROUP");
  476. }
  477. TEST_F(ARM_COMMON_MULTI_THREADS,
  478. CONV_BIAS_INT8_STRIDE1_WITHDOTPROD_SMALL_GROUP) {
  479. checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args(
  480. {2, 3, 5, 7}, 1, false, false, false),
  481. handle(), "ARMDOTS8STRD1_SMALL_GROUP");
  482. }
  483. TEST_F(ARM_COMMON_MULTI_THREADS,
  484. CONV_BIAS_INT8_STRIDE2_WITHDOTPROD_LARGE_GROUP) {
  485. checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args(
  486. {2, 3, 5, 7}, 2, false, false, false),
  487. handle(), "ARMDOTS8STRD2_LARGE_GROUP");
  488. }
  489. TEST_F(ARM_COMMON_MULTI_THREADS,
  490. CONV_BIAS_INT8_STRIDE2_WITHDOTPROD_SMALL_GROUP) {
  491. checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args(
  492. {2, 3, 5, 7}, 2, false, false, false),
  493. handle(), "ARMDOTS8STRD2_SMALL_GROUP");
  494. }
  495. /****************************dot 8-8-32 direct*************************/
  496. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_I8832STRD1_WITHDOT_LARGE_GROUP) {
  497. checker_conv_bias_qint8x8x32(
  498. get_conv_bias_args({2, 3, 5, 7}, 1, false, true, true), handle(),
  499. "ARMDOTS8STRD1_LARGE_GROUP");
  500. }
  501. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_I8832STRD1_WITHDOT_SMALL_GROUP) {
  502. checker_conv_bias_qint8x8x32(
  503. get_conv_bias_args({2, 3, 5, 7}, 1, false, true, true), handle(),
  504. "ARMDOTS8STRD1_SMALL_GROUP");
  505. }
  506. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_I8832STRD2_WITHDOT_LARGE_GROUP) {
  507. checker_conv_bias_qint8x8x32(
  508. get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(),
  509. "ARMDOTS8STRD2_LARGE_GROUP");
  510. }
  511. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_I8832STRD2_WITHDOT_SMALL_GROUP) {
  512. checker_conv_bias_qint8x8x32(
  513. get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(),
  514. "ARMDOTS8STRD2_SMALL_GROUP");
  515. }
  516. /******************************dot quint8*****************************/
  517. TEST_F(ARM_COMMON_MULTI_THREADS,
  518. CONV_BIAS_QUINT8_STRIDE1_WITHDOTPROD_LARGE_GROUP) {
  519. checker_conv_bias_quint8x8x8(get_int8_quint8_conv_bias_args(
  520. {2, 3, 5, 7}, 1, false, false, false),
  521. handle(), "ARMDOTU8STRD1_LARGE_GROUP");
  522. }
  523. TEST_F(ARM_COMMON_MULTI_THREADS,
  524. CONV_BIAS_QUINT8_STRIDE1_WITHDOTPROD_SMALL_GROUP) {
  525. checker_conv_bias_quint8x8x8(get_int8_quint8_conv_bias_args(
  526. {2, 3, 5, 7}, 1, false, false, false),
  527. handle(), "ARMDOTU8STRD1_SMALL_GROUP");
  528. }
  529. TEST_F(ARM_COMMON_MULTI_THREADS,
  530. CONV_BIAS_QUINT8_STRIDE2_WITHDOTPROD_LARGE_GROUP) {
  531. checker_conv_bias_quint8x8x8(
  532. get_int8_quint8_conv_bias_args({2, 5, 7}, 2, false, false, false),
  533. handle(), "ARMDOTU8STRD2_LARGE_GROUP");
  534. }
  535. TEST_F(ARM_COMMON_MULTI_THREADS,
  536. CONV_BIAS_QUINT8_STRIDE2_WITHDOTPROD_SMALL_GROUP) {
  537. checker_conv_bias_quint8x8x8(
  538. get_int8_quint8_conv_bias_args({2, 5, 7}, 2, false, false, false),
  539. handle(), "ARMDOTU8STRD2_SMALL_GROUP");
  540. }
  541. /******************************dot quint8x8x32***********************/
  542. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_QUINT8_DIRECT_STRIDE1_LARGE_GROUP) {
  543. checker_conv_bias_quint8x8x32(
  544. get_conv_bias_args({2, 3, 5, 7}, 1, false, true, true), handle(),
  545. "ARMDOTU8STRD1_LARGE_GROUP");
  546. }
  547. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_QUINT8_DIRECT_STRIDE1_SMALL_GROUP) {
  548. checker_conv_bias_quint8x8x32(
  549. get_conv_bias_args({2, 3, 5, 7}, 1, false, true, true), handle(),
  550. "ARMDOTU8STRD1_SMALL_GROUP");
  551. }
  552. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_QUINT8_DIRECT_STRIDE2_LARGE_GROUP) {
  553. checker_conv_bias_quint8x8x32(
  554. get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(),
  555. "ARMDOTU8STRD2_LARGE_GROUP");
  556. }
  557. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_QUINT8_DIRECT_STRIDE2_SMALL_GROUP) {
  558. checker_conv_bias_quint8x8x32(
  559. get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(),
  560. "ARMDOTU8STRD2_SMALL_GROUP");
  561. }
  562. #endif
  563. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F23_4) {
  564. using namespace conv_bias;
  565. std::vector<TestArg> args = get_winograd_mk_packed_args();
  566. Checker<ConvBiasForward> checker(handle());
  567. check_winograd("4:2:32", checker, args, param::MatrixMul::Format::MK4);
  568. }
  569. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F63) {
  570. using namespace conv_bias;
  571. std::vector<TestArg> args = get_winograd_args(3);
  572. Checker<ConvBiasForward> checker(handle());
  573. check_winograd("1:6:32", checker, args);
  574. }
  575. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F63_4) {
  576. using namespace conv_bias;
  577. std::vector<TestArg> args = get_winograd_mk_packed_args();
  578. Checker<ConvBiasForward> checker(handle());
  579. check_winograd("4:6:32", checker, args, param::MatrixMul::Format::MK4);
  580. }
  581. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F54) {
  582. using namespace conv_bias;
  583. std::vector<TestArg> args = get_winograd_args(4);
  584. Checker<ConvBiasForward> checker(handle());
  585. check_winograd("1:5:32", checker, args);
  586. }
  587. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F45) {
  588. using namespace conv_bias;
  589. std::vector<TestArg> args = get_winograd_args(5);
  590. Checker<ConvBiasForward> checker(handle());
  591. check_winograd("1:4:32", checker, args);
  592. }
// End-to-end winograd verification: the checker's extra (reference) impl
// explicitly pre-transforms the filter with WinogradFilterPreprocess and then
// runs ConvBias in NCHW_WINOGRAD format, comparing against the default path.
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD) {
    using namespace conv_bias;
    std::vector<TestArg> args = get_winograd_args(3);
    Checker<ConvBiasForward> checker(handle());
    // Reference implementation used by the checker; `m` is the winograd
    // output block size.
    auto extra_impl = [](const TensorNDArray& tensors, uint32_t m,
                         param::ConvBias param, Handle* handle) {
        megdnn_assert(param.format == param::ConvBias::Format::NCHW);
        auto winograd_preprocess_opr =
                handle->create_operator<WinogradFilterPreprocess>();
        winograd_preprocess_opr->param().output_block_size = m;
        TensorLayout filter_transform_layout;
        winograd_preprocess_opr->deduce_layout(tensors[1].layout,
                                               filter_transform_layout);
        size_t winograd_preprocess_workspace_in_bytes =
                winograd_preprocess_opr->get_workspace_in_bytes(
                        tensors[1].layout, filter_transform_layout);
        auto conv_bias_opr = handle->create_operator<ConvBias>();
        conv_bias_opr->param() = param;
        conv_bias_opr->param().format = param::ConvBias::Format::NCHW_WINOGRAD;
        conv_bias_opr->param().output_block_size = m;
        size_t conv_bias_workspace_in_bytes =
                conv_bias_opr->get_workspace_in_bytes(
                        tensors[0].layout, filter_transform_layout,
                        tensors[2].layout, tensors[3].layout,
                        tensors[4].layout);
        // One raw buffer split three ways: transformed filter, conv-bias
        // workspace, preprocess workspace.
        // NOTE(review): malloc result is not checked before use.
        WorkspaceBundle wb(nullptr, {filter_transform_layout.span().dist_byte(),
                                     conv_bias_workspace_in_bytes,
                                     winograd_preprocess_workspace_in_bytes});
        wb.set(malloc(wb.total_size_in_bytes()));
        TensorND filter_transform_tensor(wb.get(0),
                                         std::move(filter_transform_layout));
        winograd_preprocess_opr->exec(tensors[1], filter_transform_tensor,
                                      wb.get_workspace(2));
        conv_bias_opr->exec(tensors[0], filter_transform_tensor, tensors[2],
                            tensors[3], tensors[4], wb.get_workspace(1));
        free(wb.ptr());
    };
    // Sweep all args/output-block sizes with the given dtypes and tolerance.
    auto run = [&checker, &extra_impl](
                       Handle* handle, const std::vector<TestArg>& args,
                       const std::vector<size_t>& out_size, DType A_dtype,
                       DType B_dtype, DType C_dtype, DType D_dtype,
                       const float eps) {
        for (auto&& arg : args) {
            for (uint32_t m : out_size) {
                checker.set_extra_opr_impl(std::bind(extra_impl,
                                                     std::placeholders::_1, m,
                                                     arg.param, handle));
                checker.set_dtype(0, A_dtype)
                        .set_dtype(1, B_dtype)
                        .set_dtype(2, C_dtype)
                        .set_dtype(4, D_dtype)
                        .set_epsilon(eps)
                        .set_param(arg.param)
                        .execs({arg.src, arg.filter, arg.bias, {}, {}});
            }
        }
    };
    run(handle(), args, {6}, dtype::Float32(), dtype::Float32(),
        dtype::Float32(), dtype::Float32(), 1e-3f);
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    // NOTE(review): rng is heap-allocated and never freed (test-scope leak).
    Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00);
    checker.set_rng(0, rng).set_rng(1, rng).set_rng(2, rng);
    run(handle(), args, {6}, dtype::Float16(), dtype::Float16(),
        dtype::Float16(), dtype::Float16(), 0.35f);
#endif
}
//! MK4-packed fp32 winograd, output block sizes {2, 6}; checked against the
//! shared winograd_algo_extra_impl reference. Args are split into two halves
//! (_1/_2 tests) to keep single-test runtime bounded.
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_MK_PACKED_F32_1) {
    using namespace conv_bias;
    Checker<ConvBiasForward> checker(handle());
    auto run = [&checker](Handle* handle, const std::vector<TestArg>& args,
                          const std::vector<size_t>& out_size, DType A_dtype,
                          DType B_dtype, DType C_dtype, DType D_dtype,
                          param::MatrixMul::Format format, float eps) {
        for (auto&& arg : args) {
            for (uint32_t m : out_size) {
                checker.set_extra_opr_impl(std::bind(
                        winograd_algo_extra_impl, std::placeholders::_1, m,
                        arg.param, handle, format));
                checker.set_dtype(0, A_dtype)
                        .set_dtype(1, B_dtype)
                        .set_dtype(2, C_dtype)
                        .set_dtype(4, D_dtype)
                        .set_epsilon(eps)
                        .set_param(arg.param)
                        .execs({arg.src, arg.filter, arg.bias, {}, {}});
            }
        }
    };
    std::vector<TestArg> args = get_winograd_mk_packed_args(8);
    //! first half of the args only; second half runs in ..._F32_2
    std::vector<TestArg> args_first_half(args.begin(),
                                         args.begin() + args.size() / 2);
    run(handle(), args_first_half, {2, 6}, dtype::Float32{}, dtype::Float32{},
        dtype::Float32{}, dtype::Float32{}, param::MatrixMul::Format::MK4,
        1e-3f);
}
//! MK4-packed fp32 winograd, second half of the args (see ..._F32_1).
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_MK_PACKED_F32_2) {
    using namespace conv_bias;
    Checker<ConvBiasForward> checker(handle());
    auto run = [&checker](Handle* handle, const std::vector<TestArg>& args,
                          const std::vector<size_t>& out_size, DType A_dtype,
                          DType B_dtype, DType C_dtype, DType D_dtype,
                          param::MatrixMul::Format format, float eps) {
        for (auto&& arg : args) {
            for (uint32_t m : out_size) {
                checker.set_extra_opr_impl(std::bind(
                        winograd_algo_extra_impl, std::placeholders::_1, m,
                        arg.param, handle, format));
                checker.set_dtype(0, A_dtype)
                        .set_dtype(1, B_dtype)
                        .set_dtype(2, C_dtype)
                        .set_dtype(4, D_dtype)
                        .set_epsilon(eps)
                        .set_param(arg.param)
                        .execs({arg.src, arg.filter, arg.bias, {}, {}});
            }
        }
    };
    std::vector<TestArg> args = get_winograd_mk_packed_args(8);
    //! second half of the args only; first half runs in ..._F32_1
    std::vector<TestArg> args_second_half(args.begin() + args.size() / 2,
                                          args.end());
    run(handle(), args_second_half, {2, 6}, dtype::Float32{}, dtype::Float32{},
        dtype::Float32{}, dtype::Float32{}, param::MatrixMul::Format::MK4,
        1e-3f);
}
  717. #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  718. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_MK_PACKED_F16) {
  719. using namespace conv_bias;
  720. Checker<ConvBiasForward> checker(handle());
  721. auto run = [&checker](Handle* handle, const std::vector<TestArg>& args,
  722. const std::vector<size_t>& out_size, DType A_dtype,
  723. DType B_dtype, DType C_dtype, DType D_dtype,
  724. param::MatrixMul::Format format, float eps) {
  725. for (auto&& arg : args) {
  726. for (uint32_t m : out_size) {
  727. checker.set_extra_opr_impl(std::bind(
  728. winograd_algo_extra_impl, std::placeholders::_1, m,
  729. arg.param, handle, format));
  730. checker.set_dtype(0, A_dtype)
  731. .set_dtype(1, B_dtype)
  732. .set_dtype(2, C_dtype)
  733. .set_dtype(4, D_dtype)
  734. .set_epsilon(eps)
  735. .set_param(arg.param)
  736. .execs({arg.src, arg.filter, arg.bias, {}, {}});
  737. }
  738. }
  739. };
  740. std::vector<TestArg> args = get_winograd_mk_packed_args(8);
  741. Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00);
  742. checker.set_rng(0, rng).set_rng(1, rng).set_rng(2, rng);
  743. run(handle(), args, {2}, dtype::Float16{}, dtype::Float16{},
  744. dtype::Float16{}, dtype::Float16{}, param::MatrixMul::Format::MK8,
  745. 0.25);
  746. }
  747. #endif
  748. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_MK_PACKED_INT8) {
  749. using namespace conv_bias;
  750. Checker<ConvBiasForward> checker(handle());
  751. auto run = [&checker](Handle* handle, const std::vector<TestArg>& args,
  752. const std::vector<size_t>& out_size, DType A_dtype,
  753. DType B_dtype, DType C_dtype, DType D_dtype,
  754. param::MatrixMul::Format format, float eps) {
  755. for (auto&& arg : args) {
  756. for (uint32_t m : out_size) {
  757. checker.set_extra_opr_impl(std::bind(
  758. winograd_algo_extra_impl, std::placeholders::_1, m,
  759. arg.param, handle, format));
  760. checker.set_dtype(0, A_dtype)
  761. .set_dtype(1, B_dtype)
  762. .set_dtype(2, C_dtype)
  763. .set_dtype(4, D_dtype)
  764. .set_epsilon(eps)
  765. .set_param(arg.param)
  766. .execs({arg.src, arg.filter, arg.bias, {}, {}});
  767. }
  768. }
  769. };
  770. #if MEGDNN_AARCH64
  771. const char* matmul_name = "AARCH64_INT16X16X32_MK8_8X8";
  772. #else
  773. const char* matmul_name = "ARMV7_INT16X16X32_MK8_4X8";
  774. #endif
  775. checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>(
  776. ssprintf("WINOGRAD:%s:8:2:32", matmul_name).c_str()));
  777. std::vector<TestArg> args = get_winograd_mk_packed_args(8);
  778. std::vector<TestArg> quantized_args =
  779. get_quantized_winograd_mk_packed_args(8);
  780. UniformIntRNG int_rng{-50, 50};
  781. checker.set_rng(0, &int_rng).set_rng(1, &int_rng).set_rng(2, &int_rng);
  782. run(handle(), quantized_args, {2}, dtype::QuantizedS8(2.5f),
  783. dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f),
  784. dtype::QuantizedS8(60.25f), param::MatrixMul::Format::MK8, 1e-3);
  785. }
  786. #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  787. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_F23) {
  788. using namespace conv_bias;
  789. std::vector<TestArg> args = get_winograd_mk_packed_args();
  790. Checker<ConvBiasForward> checker(handle());
  791. check_winograd_fp16("1:2:32", checker, args, NULL, 0.08);
  792. }
  793. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_F45_1) {
  794. using namespace conv_bias;
  795. std::vector<TestArg> args = get_winograd_args(5);
  796. std::vector<TestArg> args_head_half(args.begin(),
  797. args.begin() + args.size() / 2);
  798. Checker<ConvBiasForward> checker(handle());
  799. //! fp16 range -1.0 ~ 1.0
  800. Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00);
  801. check_winograd_fp16("1:4:32", checker, args_head_half, rng, 0.25);
  802. }
  803. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_F45_2) {
  804. using namespace conv_bias;
  805. std::vector<TestArg> args = get_winograd_args(5);
  806. std::vector<TestArg> args_back_half(args.begin() + args.size() / 2,
  807. args.end());
  808. Checker<ConvBiasForward> checker(handle());
  809. //! fp16 range -1.0 ~ 1.0
  810. Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00);
  811. check_winograd_fp16("1:4:32", checker, args_back_half, rng, 0.25);
  812. }
  813. //! FIXME: This test may be failed if run `ARM_COMMON.CONV_BIAS_WINOGRAD*`, but
  814. //! it will pass when run single testcase
  815. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_F63) {
  816. using namespace conv_bias;
  817. std::vector<TestArg> args = get_winograd_args(3);
  818. Checker<ConvBiasForward> checker(handle());
  819. //! fp16 range -1.0 ~ 1.0
  820. Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00);
  821. check_winograd_fp16("1:6:32", checker, args, rng, 0.3);
  822. }
  823. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_8x8_1) {
  824. using namespace conv_bias;
  825. std::vector<TestArg> args = get_winograd_mk_packed_args(8);
  826. std::vector<TestArg> args_head_half(args.begin(),
  827. args.begin() + args.size() / 2);
  828. Checker<ConvBiasForward> checker(handle());
  829. Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00);
  830. check_winograd_fp16("8:2:32", checker, args_head_half, rng, 0.25,
  831. param::MatrixMul::Format::MK8);
  832. }
  833. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_8x8_2) {
  834. using namespace conv_bias;
  835. std::vector<TestArg> args = get_winograd_mk_packed_args(8);
  836. std::vector<TestArg> args_back_half(args.begin() + args.size() / 2,
  837. args.end());
  838. Checker<ConvBiasForward> checker(handle());
  839. Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00);
  840. check_winograd_fp16("8:2:32", checker, args_back_half, rng, 0.25,
  841. param::MatrixMul::Format::MK8);
  842. }
  843. #endif
//! Quantized-int8 MK8-packed winograd through the generic check_winograd
//! helper (qs8 src/filter, qs32 bias, qs8 dst).
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_INT8_8X8) {
    using namespace conv_bias;
    std::vector<TestArg> args = get_quantized_winograd_mk_packed_args(8);
    Checker<ConvBiasForward> checker(handle());
    UniformIntRNG rng{-50, 50};
    checker.set_dtype(0, dtype::QuantizedS8(2.5f))
            .set_dtype(1, dtype::QuantizedS8(2.5f))
            .set_dtype(2, dtype::QuantizedS32(6.25f))
            .set_dtype(4, dtype::QuantizedS8(60.25f))
            .set_rng(0, &rng)
            .set_rng(1, &rng)
            .set_rng(2, &rng);
    check_winograd("8:2:32", checker, args, param::MatrixMul::Format::MK8);
}
  858. void checker_conv_bias(std::vector<conv_bias::TestArg> args, Handle* handle,
  859. RNG* rng, float epsilon, DType type0, DType type1,
  860. DType type2, DType type3, const char* algo_name) {
  861. using namespace conv_bias;
  862. Checker<ConvBias> checker(handle);
  863. checker.set_before_exec_callback(
  864. conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name));
  865. checker.set_dtype(0, type0);
  866. checker.set_dtype(1, type1);
  867. checker.set_dtype(2, type2);
  868. checker.set_dtype(4, type3);
  869. checker.set_epsilon(epsilon);
  870. if (NULL != rng) {
  871. checker.set_rng(0, rng).set_rng(1, rng).set_rng(2, rng).set_rng(3, rng);
  872. }
  873. for (auto&& arg : args) {
  874. checker.set_param(arg.param).execs(
  875. {arg.src, arg.filter, arg.bias, {}, {}});
  876. }
  877. }
  878. // clang-format off
//! im2col fp32, stride 2, filter sizes 1-7; one cb() per candidate matmul.
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COL_FP32_STRIDE2) {
#define cb(name)                                                              \
    check_conv_bias(                                                          \
            get_conv_bias_args({1, 2, 3, 4, 5, 6, 7}, 2, false, false, false), \
            handle(), name);
#if MEGDNN_AARCH64
    cb("IM2COLMATMUL:AARCH64_F32K8X12X1")
    cb("IM2COLMATMUL:AARCH64_F32K4X16X1")
    cb("IM2COLMATMUL:FB_F32_K8X12X1")
#elif MEGDNN_ARMV7
    cb("IM2COLMATMUL:ARMV7_F32")
#endif
#undef cb
}
//! im2col fp32, stride 1, filter sizes 2-7; one cb() per candidate matmul.
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COL_FP32_STRIDE1) {
#define cb(name)                                                             \
    check_conv_bias(                                                         \
            get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, false, false),  \
            handle(), name);
#if MEGDNN_AARCH64
    cb("IM2COLMATMUL:AARCH64_F32K8X12X1")
    cb("IM2COLMATMUL:AARCH64_F32K4X16X1")
    cb("IM2COLMATMUL:FB_F32_K8X12X1")
#elif MEGDNN_ARMV7
    cb("IM2COLMATMUL:ARMV7_F32")
    cb("IM2COLMATMUL:FB_F32_K8X12X1")
#endif
#undef cb
}
//! im2col with symmetric quantized dtypes: qs8 src/filter, qs32 bias,
//! qs8 dst; stride-1 filter sizes 2-7 plus a stride-2 1x1 case.
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM) {
    UniformIntRNG rng{-50, 50};
#define cb(name)                                                              \
    checker_conv_bias(get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, false, \
                                         false, true, true),                  \
                      handle(), &rng, epsilon, dtype::QuantizedS8(2.5f),      \
                      dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f),   \
                      dtype::QuantizedS8(60.25f), name);                      \
    checker_conv_bias(                                                        \
            get_conv_bias_args({1}, 2, false, false, false, true, true),      \
            handle(), &rng, epsilon, dtype::QuantizedS8(2.5f),                \
            dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f),             \
            dtype::QuantizedS8(60.25f), name);
    float epsilon = 0.001;
#if MEGDNN_AARCH64
#if __ARM_FEATURE_DOTPROD
    cb("IM2COLMATMUL:AARCH64_INT8X8X32_K8X12X4_DOTPROD");
#else
    cb("IM2COLMATMUL:AARCH64_INT8X8X32_K8X8X8");
    cb("IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16");
#endif
#elif MEGDNN_ARMV7
    //! relaxed tolerance for the armv7 kernel
    epsilon = 1;
    cb("IM2COLMATMUL:ARMV7_INT8X8X32_K4X8X8");
#endif
#undef cb
}
  935. // clang-format on
  936. #if MEGDNN_AARCH64 || MEGDNN_ARMV7
//! im2col with asymmetric quantized dtypes (quint8 src/filter with distinct
//! zero points, qs32 bias, quint8 dst).
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUANTIZEDASYM) {
    NormalRNG rng(128.f);
#define cb(name)                                                              \
    checker_conv_bias(get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, false, \
                                         false, true, true),                  \
                      handle(), &rng, epsilon,                                \
                      dtype::Quantized8Asymm(1.2f, (uint8_t)125),             \
                      dtype::Quantized8Asymm(1.3f, (uint8_t)129),             \
                      dtype::QuantizedS32(1.2 * 1.3),                         \
                      dtype::Quantized8Asymm(50.3f, (uint8_t)120), name);     \
    checker_conv_bias(                                                        \
            get_conv_bias_args({1}, 2, false, false, false, true, true),      \
            handle(), &rng, epsilon,                                          \
            dtype::Quantized8Asymm(1.2f, (uint8_t)125),                       \
            dtype::Quantized8Asymm(1.3f, (uint8_t)129),                       \
            dtype::QuantizedS32(1.2 * 1.3),                                   \
            dtype::Quantized8Asymm(50.3f, (uint8_t)120), name);
    float epsilon = 0.001;
#if MEGDNN_AARCH64
#if __ARM_FEATURE_DOTPROD
    cb("IM2COLMATMUL:AARCH64_QUINT8_K8X8X4_DOTPROD");
#else
    cb("IM2COLMATMUL:AARCH64_QUINT8_K8X8X8");
#endif
#elif MEGDNN_ARMV7
    //! relaxed tolerance for the armv7 kernel
    epsilon = 1;
    cb("IM2COLMATMUL:ARMV7_QUINT8_K4X8X8");
#endif
#undef cb
}
  967. #endif
  968. #if MEGDNN_AARCH64 || MEGDNN_ARMV7
//! im2col quint8 inputs with qint32 output (empty dst dtype => no requant).
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUINT8x8x32) {
    UniformIntRNG rng{-50, 50};
    float epsilon = 0.001;
#define cb(name)                                                              \
    checker_conv_bias(                                                        \
            get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, true, true),     \
            handle(), &rng, epsilon,                                          \
            dtype::Quantized8Asymm(1.2f, (uint8_t)125),                       \
            dtype::Quantized8Asymm(1.3f, (uint8_t)129),                       \
            dtype::QuantizedS32(1.2 * 1.3), {}, name);                        \
    checker_conv_bias(get_conv_bias_args({1}, 2, false, true, true), handle(), \
                      &rng, epsilon,                                          \
                      dtype::Quantized8Asymm(1.2f, (uint8_t)125),             \
                      dtype::Quantized8Asymm(1.3f, (uint8_t)129),             \
                      dtype::QuantizedS32(1.2 * 1.3), {}, name);
#if MEGDNN_AARCH64
#if __ARM_FEATURE_DOTPROD
    cb("IM2COLMATMUL:AARCH64_QUINT8_K8X8X4_DOTPROD");
#else
    cb("IM2COLMATMUL:AARCH64_QUINT8_K8X8X8");
#endif
#elif MEGDNN_ARMV7
#if __ARM_FEATURE_DOTPROD
    cb("IM2COLMATMUL:AARCH32_QUINT8_K4X8X4");
#endif
    cb("IM2COLMATMUL:ARMV7_QUINT8_K4X8X8");
#endif
#undef cb
}
//! im2col int8 inputs with int16 accumulation/output.
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COLMATMUL_INT8x8x16) {
    UniformIntRNG rng{-50, 50};
    float epsilon = 0.001;
#define cb(name)                                                              \
    checker_conv_bias(                                                        \
            get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, true, true),     \
            handle(), &rng, epsilon, dtype::Int8{}, dtype::Int8{},            \
            dtype::Int16{}, dtype::Int16{}, name);                            \
    checker_conv_bias(get_conv_bias_args({1}, 2, false, true, true), handle(), \
                      &rng, epsilon, dtype::Int8{}, dtype::Int8{},            \
                      dtype::Int16{}, dtype::Int16{}, name);
#if MEGDNN_AARCH64
    cb("IM2COLMATMUL:AARCH64_INT8X8X16_K8X8X8");
    cb("IM2COLMATMUL:AARCH64_INT8X8X16_K4X4X16");
    cb("IM2COLMATMUL:ARM_COMMON_INT8X8X16");
#elif MEGDNN_ARMV7
    cb("IM2COLMATMUL:ARM_COMMON_INT8X8X16");
    cb("IM2COLMATMUL:ARMV7_INT8X8X16_K4X8X8");
    cb("IM2COLMATMUL:ARMV7_INT8X8X16_K4X2X16");
#endif
#undef cb
}
  1020. #endif
  1021. #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
//! im2col fp16: stride-1 filter sizes 2-7 plus stride-2 1x1, all merged
//! into one arg list.
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_FP16) {
    using namespace conv_bias;
    param::ConvBias cur_param;
    std::vector<conv_bias::TestArg> args =
            get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, false, false);
    std::vector<conv_bias::TestArg> args1 =
            get_conv_bias_args({1}, 2, false, false, false);
    args.insert(args.begin(), args1.begin(), args1.end());
    NormalRNG rng(1);
#define cb(name)                                                          \
    checker_conv_bias(args, handle(), &rng, 0.03, dtype::Float16{},       \
                      dtype::Float16{}, dtype::Float16{}, dtype::Float16{}, \
                      name);
#if MEGDNN_AARCH64
    cb("IM2COLMATMUL:AARCH64_F16_K8X24X1");
#elif MEGDNN_ARMV7
    cb("IM2COLMATMUL:AARCH32_F16_K4X16X1");
#endif
#undef cb
}
  1042. #endif
  1043. void checker_conv_bias_mul_int8x8x32(std::vector<conv_bias::TestArg> args,
  1044. Handle* handle, const char* algo_name) {
  1045. using namespace conv_bias;
  1046. Checker<ConvBias> checker(handle);
  1047. checker.set_before_exec_callback(
  1048. conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name));
  1049. checker.set_dtype(0, dtype::Int8());
  1050. checker.set_dtype(1, dtype::Int8());
  1051. checker.set_dtype(2, dtype::Int32());
  1052. checker.set_dtype(4, dtype::Int32());
  1053. for (auto&& arg : args) {
  1054. checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}});
  1055. }
  1056. UniformIntRNG rng{-50, 50};
  1057. for (auto&& arg : args) {
  1058. checker.set_dtype(0, dtype::QuantizedS8(2.5f))
  1059. .set_dtype(1, dtype::QuantizedS8(2.5f))
  1060. .set_dtype(2, dtype::QuantizedS32(6.25f))
  1061. .set_dtype(4, {})
  1062. .set_rng(0, &rng)
  1063. .set_rng(1, &rng)
  1064. .set_rng(2, &rng)
  1065. .set_param(arg.param)
  1066. .execs({arg.src, arg.filter, {}, {}, {}});
  1067. }
  1068. }
  1069. #if MEGDNN_AARCH64 || MEGDNN_ARMV7
  1070. #if !__ARM_FEATURE_DOTPROD
//! NCHW44 int8x8x32 im2col, single-thread handle (ARM_COMMON fixture).
TEST_F(ARM_COMMON, CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44) {
    using namespace conv_bias;
    std::vector<conv_bias::TestArg> args =
            get_nchw44_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, true, true);
#define cb(name) checker_conv_bias_mul_int8x8x32(args, handle(), name);
#if MEGDNN_AARCH64
    cb("IM2COLMATMUL:AARCH64_INT8X8X32_MK4_4X4X16:96");
#else
    cb("IM2COLMATMUL:ARMV7_INT8X8X32_MK4_4X2X16:96");
#endif
#undef cb
}
//! NCHW44 int8x8x32 im2col, multi-thread handle (same args as the
//! single-thread variant above).
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44_MULTI) {
    using namespace conv_bias;
    std::vector<conv_bias::TestArg> args =
            get_nchw44_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, true, true);
#define cb(name) checker_conv_bias_mul_int8x8x32(args, handle(), name);
#if MEGDNN_AARCH64
    cb("IM2COLMATMUL:AARCH64_INT8X8X32_MK4_4X4X16:96");
#else
    cb("IM2COLMATMUL:ARMV7_INT8X8X32_MK4_4X2X16:96");
#endif
#undef cb
}
//! NCHW44 symmetric-quantized im2col, single-thread handle.
TEST_F(ARM_COMMON, CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM_NCHW44) {
    UniformIntRNG rng{-50, 50};
#define cb(name)                                                           \
    checker_conv_bias(get_nchw44_conv_bias_args({2, 3, 4, 5, 6, 7}, 1),    \
                      handle(), &rng, epsilon, dtype::QuantizedS8(2.5f),   \
                      dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), \
                      dtype::QuantizedS8(60.25f), name);
    float epsilon = 0.001;
#if MEGDNN_AARCH64
    cb("IM2COLMATMUL:AARCH64_INT8X8X32_MK4_4X4X16:96");
#else
    cb("IM2COLMATMUL:ARMV7_INT8X8X32_MK4_4X2X16:96");
#endif
#undef cb
}
//! NCHW44 symmetric-quantized im2col, multi-thread handle.
TEST_F(ARM_COMMON_MULTI_THREADS,
       CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM_NCHW44_MULTI) {
    UniformIntRNG rng{-50, 50};
#define cb(name)                                                           \
    checker_conv_bias(get_nchw44_conv_bias_args({2, 3, 4, 5, 6, 7}, 1),    \
                      handle(), &rng, epsilon, dtype::QuantizedS8(2.5f),   \
                      dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), \
                      dtype::QuantizedS8(60.25f), name);
    float epsilon = 0.001;
#if MEGDNN_AARCH64
    cb("IM2COLMATMUL:AARCH64_INT8X8X32_MK4_4X4X16:96");
#else
    cb("IM2COLMATMUL:ARMV7_INT8X8X32_MK4_4X2X16:96");
#endif
#undef cb
}
  1126. #endif
  1127. #endif
//! NCHW int8x8x32 im2col: stride-1 filter sizes 2-7 plus stride-2 1x1.
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32) {
    using namespace conv_bias;
    std::vector<conv_bias::TestArg> args =
            get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, true, true);
    std::vector<conv_bias::TestArg> args1 =
            get_conv_bias_args({1}, 2, false, true, true);
    args.insert(args.begin(), args1.begin(), args1.end());
#define cb(name) checker_conv_bias_mul_int8x8x32(args, handle(), name);
#if MEGDNN_AARCH64
#if __ARM_FEATURE_DOTPROD
    cb("IM2COLMATMUL:AARCH64_INT8X8X32_K8X12X4_DOTPROD");
#else
    cb("IM2COLMATMUL:AARCH64_INT8X8X32_K8X8X8");
    cb("IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16");
#endif
#elif MEGDNN_ARMV7
#if __ARM_FEATURE_DOTPROD
    cb("IM2COLMATMUL:AARCH32_INT8_K6X8X4");
#endif
    cb("IM2COLMATMUL:ARMV7_INT8X8X32_K4X8X8");
#endif
#if MEGDNN_ARMV7
    cb("IM2COLMATMUL:ARMV7_INT8X8X32_K4X2X16");
#endif
#undef cb
}
  1154. /***************************** Conv1x1 Algo Test ***********************/
//! Conv1x1 (stride-1 1x1 shortcut) fp32.
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_F32) {
    using namespace conv_bias;
    std::vector<conv_bias::TestArg> args = get_conv_bias_1x1_args(false, false);
#if MEGDNN_AARCH64
    check_conv_bias(args, handle(), "CONV1x1:AARCH64_F32K8X12X1:24");
#elif MEGDNN_ARMV7
    check_conv_bias(args, handle(), "CONV1x1:ARMV7_F32:48");
#endif
}
  1164. #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
//! Conv1x1 fp16.
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_F16) {
    using namespace conv_bias;
    std::vector<conv_bias::TestArg> args = get_conv_bias_1x1_args(false, false);
    NormalRNG rng(1);
#if MEGDNN_AARCH64
    checker_conv_bias(args, handle(), &rng, 0.03, dtype::Float16{},
                      dtype::Float16{}, dtype::Float16{}, dtype::Float16{},
                      "CONV1x1:AARCH64_F16_K8X24X1:48");
#elif MEGDNN_ARMV7
    checker_conv_bias(args, handle(), &rng, 0.03, dtype::Float16{},
                      dtype::Float16{}, dtype::Float16{}, dtype::Float16{},
                      "CONV1x1:AARCH32_F16_K4X16X1:24");
#endif
}
  1179. #endif
//! Conv1x1 with symmetric quantized dtypes.
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_QUANTIZEDSYM) {
    UniformIntRNG rng{-50, 50};
    float epsilon = 0.001;
#define cb(name)                                                           \
    checker_conv_bias(get_conv_bias_1x1_args(false, false, true, true),    \
                      handle(), &rng, epsilon, dtype::QuantizedS8(2.5f),   \
                      dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), \
                      dtype::QuantizedS8(60.25f), name);
#if MEGDNN_AARCH64
#if __ARM_FEATURE_DOTPROD
    cb("CONV1x1:AARCH64_INT8X8X32_K8X12X4_DOTPROD:24");
#else
    cb("CONV1x1:AARCH64_INT8X8X32_K8X8X8:24");
    cb("CONV1x1:AARCH64_INT8X8X32_K4X4X16:48");
#endif
#elif MEGDNN_ARMV7
    //! relaxed tolerance for the armv7 kernel
    epsilon = 1;
    cb("CONV1x1:ARMV7_INT8X8X32_K4X8X8:48");
#endif
#undef cb
}
  1201. #if MEGDNN_AARCH64 || MEGDNN_ARMV7
//! Conv1x1 with asymmetric quantized dtypes (quint8 in/out, qs32 bias).
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_QUANTIZEDASYM) {
    NormalRNG rng(128.f);
#define cb(name)                                                         \
    checker_conv_bias(get_conv_bias_1x1_args(false, false, true, true),  \
                      handle(), &rng, epsilon,                           \
                      dtype::Quantized8Asymm(1.2f, (uint8_t)125),        \
                      dtype::Quantized8Asymm(1.3f, (uint8_t)129),        \
                      dtype::QuantizedS32(1.2 * 1.3),                    \
                      dtype::Quantized8Asymm(50.3f, (uint8_t)120), name);
    float epsilon = 0.001;
#if MEGDNN_AARCH64
#if __ARM_FEATURE_DOTPROD
    cb("CONV1x1:AARCH64_QUINT8_K8X8X4_DOTPROD:48");
#else
    cb("CONV1x1:AARCH64_QUINT8_K8X8X8:24");
#endif
#elif MEGDNN_ARMV7
    //! relaxed tolerance for the armv7 kernel
    epsilon = 1;
    cb("CONV1x1:ARMV7_QUINT8_K4X8X8:48");
#endif
#undef cb
}
  1224. #endif
  1225. #if MEGDNN_AARCH64 || MEGDNN_ARMV7
//! Conv1x1 quint8 inputs with qint32 output (empty dst dtype => no requant).
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_QUINT8x8x32) {
    UniformIntRNG rng{-50, 50};
    float epsilon = 0.001;
#define cb(name)                                                            \
    checker_conv_bias(get_conv_bias_1x1_args(true, true), handle(), &rng,   \
                      epsilon, dtype::Quantized8Asymm(1.2f, (uint8_t)125),  \
                      dtype::Quantized8Asymm(1.3f, (uint8_t)129),           \
                      dtype::QuantizedS32(1.2 * 1.3), {}, name);
#if MEGDNN_AARCH64
#if __ARM_FEATURE_DOTPROD
    cb("CONV1x1:AARCH64_QUINT8_K8X8X4_DOTPROD:24");
#else
    cb("CONV1x1:AARCH64_QUINT8_K8X8X8:48");
#endif
#elif MEGDNN_ARMV7
#if __ARM_FEATURE_DOTPROD
    cb("CONV1x1:AARCH32_QUINT8_K4X8X4:48");
#endif
    cb("CONV1x1:ARMV7_QUINT8_K4X8X8:24");
#endif
#undef cb
}
//! Conv1x1 int8 inputs with int16 accumulation/output.
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_1X1_S1_INT8x8x16) {
    UniformIntRNG rng{-50, 50};
    float epsilon = 0.001;
#define cb(name)                                                             \
    checker_conv_bias(get_conv_bias_1x1_args(true, true), handle(), &rng,    \
                      epsilon, dtype::Int8{}, dtype::Int8{}, dtype::Int16{}, \
                      dtype::Int16{}, name);
#if MEGDNN_AARCH64
    cb("CONV1x1:AARCH64_INT8X8X16_K8X8X8:24");
    cb("CONV1x1:AARCH64_INT8X8X16_K4X4X16:24");
#elif MEGDNN_ARMV7
    cb("CONV1x1:ARMV7_INT8X8X16_K4X8X8:24");
    cb("CONV1x1:ARMV7_INT8X8X16_K4X2X16:48");
#endif
    //! the generic arm_common kernel is available on both architectures
    cb("CONV1x1:ARM_COMMON_INT8X8X16:48");
#undef cb
}
  1265. #endif
//! Conv1x1 int8x8x32 (plain and quantized passes via the shared helper).
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_INT8x8x32) {
    using namespace conv_bias;
    std::vector<conv_bias::TestArg> args = get_conv_bias_1x1_args(true, true);
#define cb(name) checker_conv_bias_mul_int8x8x32(args, handle(), name);
#if MEGDNN_AARCH64
#if __ARM_FEATURE_DOTPROD
    cb("CONV1x1:AARCH64_INT8X8X32_K8X12X4_DOTPROD:48");
#else
    cb("CONV1x1:AARCH64_INT8X8X32_K8X8X8:24");
    cb("CONV1x1:AARCH64_INT8X8X32_K4X4X16:24");
#endif
#elif MEGDNN_ARMV7
#if __ARM_FEATURE_DOTPROD
    cb("CONV1x1:AARCH32_INT8_K6X8X4:48");
#endif
    cb("CONV1x1:ARMV7_INT8X8X32_K4X8X8:24");
#endif
#if MEGDNN_ARMV7
    cb("CONV1x1:ARMV7_INT8X8X32_K4X2X16:48");
#endif
#undef cb
}
  1288. #ifndef __ARM_FEATURE_DOTPROD
  1289. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_INT8x8x32_MK4) {
  1290. using namespace conv_bias;
  1291. std::vector<conv_bias::TestArg> args =
  1292. get_nchw44_conv_bias_args({1}, 1, true, true, true);
  1293. #define cb(name) checker_conv_bias_mul_int8x8x32(args, handle(), name);
  1294. #if MEGDNN_AARCH64
  1295. cb("CONV1x1:AARCH64_INT8X8X32_MK4_4X4X16:24");
  1296. #elif MEGDNN_ARMV7
  1297. cb("CONV1x1:ARMV7_INT8X8X32_MK4_4X2X16:24");
  1298. #endif
  1299. #undef cb
  1300. UniformIntRNG rng{-50, 50};
  1301. float epsilon = 0.001;
  1302. #define cb(name) \
  1303. checker_conv_bias(get_nchw44_conv_bias_args({1}, 1, true, false, false), \
  1304. handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \
  1305. dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), \
  1306. dtype::QuantizedS8(60.25f), name);
  1307. #if MEGDNN_AARCH64
  1308. cb("CONV1x1:AARCH64_INT8X8X32_MK4_4X4X16:24");
  1309. #elif MEGDNN_ARMV7
  1310. cb("CONV1x1:ARMV7_INT8X8X32_MK4_4X2X16:24");
  1311. #endif
  1312. #undef cb
  1313. }
  1314. #endif
  1315. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台