You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

conv_bias_multi_thread.cpp 60 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446
  1. /**
  2. * \file dnn/test/arm_common/conv_bias_multi_thread.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #include "test/arm_common/fixture.h"
  13. #include "test/common/benchmarker.h"
  14. #include "test/common/conv_bias.h"
  15. using namespace megdnn;
  16. using namespace test;
  17. using namespace conv_bias;
  18. std::vector<conv_bias::TestArg> get_int8_quint8_conv_bias_args(
  19. std::vector<size_t> kernel, size_t stride, bool no_pad, bool no_bias,
  20. bool no_nonlinemode) {
  21. using namespace conv_bias;
  22. using Param = param::ConvBias;
  23. using NLMode = param::ConvBias::NonlineMode;
  24. std::vector<TestArg> args;
  25. auto pack = [&](size_t n, size_t oc, size_t ic, size_t w, size_t h,
  26. size_t kernel, size_t stride, NLMode nlmode) {
  27. Param param;
  28. param.stride_h = stride;
  29. param.stride_w = stride;
  30. if (!no_pad) {
  31. param.pad_h = kernel / 2;
  32. param.pad_w = kernel / 2;
  33. } else {
  34. param.pad_h = 0;
  35. param.pad_w = 0;
  36. }
  37. param.nonlineMode = nlmode;
  38. args.emplace_back(param, TensorShape{n, ic, h, w},
  39. TensorShape{oc, ic, kernel, kernel}, TensorShape{});
  40. if (!no_bias) {
  41. args.emplace_back(param, TensorShape{n, ic, h, w},
  42. TensorShape{oc, ic, kernel, kernel},
  43. TensorShape{1, oc, 1, 1});
  44. }
  45. };
  46. std::vector<NLMode> nonlinemode = {NLMode::IDENTITY};
  47. if (!no_nonlinemode) {
  48. nonlinemode.emplace_back(NLMode::RELU);
  49. nonlinemode.emplace_back(NLMode::H_SWISH);
  50. }
  51. for (size_t n : {1, 2}) {
  52. for (auto nlmode : nonlinemode) {
  53. for (size_t ic : {1, 3, 7}) {
  54. for (size_t oc : {1, 3, 7}) {
  55. for (size_t size : {4, 6, 8, 14, 16, 18}) {
  56. for (size_t kern : kernel) {
  57. pack(n, oc, ic, size, size, kern, stride, nlmode);
  58. }
  59. }
  60. }
  61. }
  62. }
  63. }
  64. return args;
  65. }
//! Build the TestArg set for NCHW44-format conv_bias tests.
//! \param kernel_vec      square kernel sizes to sweep
//! \param stride          stride applied to both spatial dimensions
//! \param no_pad          NOTE(review): currently unused — padding always
//!                        comes from the lambda's any_pad/kernel logic;
//!                        confirm whether callers expect it to take effect
//! \param no_bias         selects NO_BIAS vs BROADCAST_CHANNEL_BIAS
//! \param no_nonlinemode  if true only IDENTITY nonlinearity is generated
//! \param is_input_nchw   if true, generate the hybrid NCHW-input /
//!                        NCHW44-output layout instead of pure NCHW44
std::vector<conv_bias::TestArg> get_nchw44_conv_bias_args(
        std::vector<size_t> kernel_vec, size_t stride, bool no_pad = false,
        bool no_bias = false, bool no_nonlinemode = false,
        bool is_input_nchw = false) {
    using namespace conv_bias;
    using NLMode = param::ConvBias::NonlineMode;
    std::vector<TestArg> args;
    //! validate and append one case; silently skips combinations the
    //! target format cannot express
    auto pack = [&](size_t n, size_t oc, size_t ic, size_t h, size_t w,
                    size_t kernel, size_t stride, size_t group, NLMode nlmode,
                    int any_pad = -1) {
        constexpr int pack_c = 4;  //! NCHW44 packs 4 channels per block
        const size_t pad = any_pad >= 0 ? any_pad : kernel / 2;
        auto bias_mode = no_bias ? megdnn::BiasMode::NO_BIAS
                                 : megdnn::BiasMode::BROADCAST_CHANNEL_BIAS;
        auto oc_per_group = oc / group;
        auto ic_per_group = ic / group;
        //! group split must divide channels and keep oc a pack_c multiple
        bool ok_group = (oc % group == 0 && ic % group == 0) &&
                        oc_per_group % pack_c == 0 && oc_per_group > 0 &&
                        ic_per_group > 0;
        //! hybrid NCHW input only supports dense conv with few in-channels
        bool nchw_disable = group > 1 || ic_per_group >= 4;
        //! pure NCHW44 needs input channels packable into groups of 4
        bool nchw44_disable = ic_per_group % pack_c != 0;
        bool invalid_pad = (w + 2 * pad < kernel) || (h + 2 * pad < kernel);
        if (!(ok_group) || invalid_pad) {
            return;
        }
        if ((is_input_nchw && nchw_disable) ||
            (!is_input_nchw && nchw44_disable)) {
            return;
        }
        size_t kernel_h = kernel;
        size_t kernel_w = kernel;
        param::ConvBias param;
        param.format = param::ConvBias::Format::NCHW44;
        param.stride_h = stride;
        param.stride_w = stride;
        param.pad_h = pad;
        param.pad_w = pad;
        param.nonlineMode = nlmode;
        auto src_tensor_shape = TensorShape{n, ic / pack_c, h, w, pack_c};
        auto weight_tensor_shape = TensorShape{
                oc / pack_c, ic / pack_c, kernel_h, kernel_w, pack_c, pack_c};
        auto bias_tensor_shape = TensorShape{};
        if (bias_mode == megdnn::BiasMode::BROADCAST_CHANNEL_BIAS) {
            bias_tensor_shape = {1, oc / pack_c, 1, 1, pack_c};
        }
        if (group == 1) {
            param.sparse = param::ConvBias::Sparse::DENSE;
        } else if (group > 1 && ic / group == 1 && oc / group == 1) {
            //! NOTE(review): channel-wise is explicitly rejected here; the
            //! two statements below are only reached when megdnn_assert is
            //! compiled out — this branch also appears unreachable because
            //! ok_group requires oc_per_group % pack_c == 0 (1 % 4 != 0).
            megdnn_assert(0, "not support channel wise");
            param.sparse = param::ConvBias::Sparse::GROUP;
            weight_tensor_shape = TensorShape{group / pack_c, 1, 1,
                                              kernel_h, kernel_w, pack_c};
        } else if (group > 1 && oc_per_group % pack_c == 0 && oc / group > 0 &&
                   ic_per_group % pack_c == 0 && ic / group > 0) {
            param.sparse = param::ConvBias::Sparse::GROUP;
            weight_tensor_shape = TensorShape{group,
                                              oc_per_group / pack_c,
                                              ic_per_group / pack_c,
                                              kernel_h,
                                              kernel_w,
                                              pack_c,
                                              pack_c};
        }
        if (is_input_nchw) {
            //! hybrid layout: plain NCHW source, filter packed on oc only
            src_tensor_shape = TensorShape{n, ic, h, w};
            weight_tensor_shape =
                    TensorShape{oc / pack_c, kernel_h, kernel_w, ic, pack_c};
        }
        args.emplace_back(param, src_tensor_shape, weight_tensor_shape,
                          bias_tensor_shape);
    };
    std::vector<NLMode> nonlinemode = {NLMode::IDENTITY};
    if (!no_nonlinemode) {
        nonlinemode.emplace_back(NLMode::RELU);
        nonlinemode.emplace_back(NLMode::H_SWISH);
    }
    for (auto nlmode : nonlinemode)
        for (size_t n : {1, 2})
            for (size_t kernel : kernel_vec)
                for (size_t oc : {4, 12, 32})
                    for (size_t ic : {1, 3, 4, 12, 32})
                        for (size_t h : {3, 5, 12})
                            for (size_t w : {7, 16, 23}) {
                                for (size_t group = 1;
                                     group <= std::min(oc, ic); ++group) {
                                    pack(n, oc, ic, h, w, kernel, stride, group,
                                         nlmode);
                                }
                            }
    return args;
}
  157. std::vector<conv_bias::TestArg> get_int8_quint8_nchw44_channel_wise_args(
  158. std::vector<size_t> kernel, size_t stride, bool no_bias,
  159. bool no_nonlinemode) {
  160. using namespace conv_bias;
  161. using Param = param::ConvBias;
  162. using NLMode = param::ConvBias::NonlineMode;
  163. std::vector<TestArg> args;
  164. auto pack = [&](size_t n, size_t group, size_t w, size_t h, size_t kernel,
  165. size_t stride, NLMode nlmode, bool pad) {
  166. Param param;
  167. param.stride_h = stride;
  168. param.stride_w = stride;
  169. if (pad) {
  170. param.pad_h = kernel / 2;
  171. param.pad_w = kernel / 2;
  172. } else {
  173. param.pad_h = 0;
  174. param.pad_w = 0;
  175. }
  176. param.nonlineMode = nlmode;
  177. param.format = param::ConvBias::Format::NCHW44;
  178. param.sparse = param::ConvBias::Sparse::GROUP;
  179. args.emplace_back(param, TensorShape{n, group, h, w, 4},
  180. TensorShape{group, 1, 1, kernel, kernel, 4},
  181. TensorShape{});
  182. if (!no_bias) {
  183. args.emplace_back(param, TensorShape{n, group, h, w, 4},
  184. TensorShape{group, 1, 1, kernel, kernel, 4},
  185. TensorShape{1, group, 1, 1, 4});
  186. }
  187. };
  188. std::vector<NLMode> nonlinemode = {NLMode::IDENTITY};
  189. if (!no_nonlinemode) {
  190. nonlinemode.emplace_back(NLMode::RELU);
  191. nonlinemode.emplace_back(NLMode::H_SWISH);
  192. }
  193. for (size_t n : {1, 2}) {
  194. for (auto nlmode : nonlinemode) {
  195. for (bool pad : {true}) {
  196. for (size_t group : {1, 2, 4, 7, 128}) {
  197. for (size_t size : {4, 5, 6, 7, 8, 9, 10, 15, 40}) {
  198. for (size_t kern : kernel) {
  199. pack(n, group, size, size, kern, stride, nlmode,
  200. pad);
  201. }
  202. }
  203. }
  204. }
  205. for (bool pad : {false}) {
  206. for (size_t group : {1, 2, 7, 128}) {
  207. for (size_t size : {7, 8, 9, 10, 15, 40}) {
  208. for (size_t kern : kernel) {
  209. pack(n, group, size, size, kern, stride, nlmode,
  210. pad);
  211. }
  212. }
  213. }
  214. }
  215. }
  216. }
  217. return args;
  218. }
  219. void checker_conv_bias_qint8x8x8(std::vector<conv_bias::TestArg> args,
  220. Handle* handle, const char* algo_name) {
  221. Checker<ConvBias> checker(handle);
  222. checker.set_before_exec_callback(
  223. conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name));
  224. #if MEGDNN_ARMV7
  225. checker.set_epsilon(1);
  226. #endif
  227. UniformIntRNG rng{-50, 50};
  228. checker.set_dtype(0, dtype::QuantizedS8(0.41113496f))
  229. .set_dtype(1, dtype::QuantizedS8(0.01887994f))
  230. .set_dtype(2, dtype::QuantizedS32(0.41113496f * 0.01887994f))
  231. .set_dtype(4, dtype::QuantizedS8(0.49550694f))
  232. .set_rng(0, &rng)
  233. .set_rng(1, &rng)
  234. .set_rng(2, &rng);
  235. for (auto&& arg : args) {
  236. checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}});
  237. }
  238. }
  239. void checker_conv_bias_qint8x8x32(std::vector<conv_bias::TestArg> args,
  240. Handle* handle, const char* algo_name) {
  241. Checker<ConvBias> checker(handle);
  242. UniformIntRNG rng{-50, 50};
  243. checker.set_dtype(0, dtype::QuantizedS8(2.5f))
  244. .set_dtype(1, dtype::QuantizedS8(2.5f))
  245. .set_dtype(2, dtype::QuantizedS32(6.25f))
  246. .set_dtype(4, {});
  247. checker.set_before_exec_callback(
  248. conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name));
  249. for (auto&& arg : args) {
  250. checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}});
  251. }
  252. }
  253. void checker_conv_bias_quint8x8x8(std::vector<conv_bias::TestArg> args,
  254. Handle* handle, const char* algo_name) {
  255. Checker<ConvBias> checker(handle);
  256. checker.set_before_exec_callback(
  257. conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name));
  258. UniformIntRNG rng(0, 255);
  259. checker.set_dtype(0, dtype::Quantized8Asymm(0.2f, 100))
  260. .set_dtype(1, dtype::Quantized8Asymm(0.2f, 120))
  261. .set_dtype(2, dtype::QuantizedS32(0.04f))
  262. .set_dtype(4, dtype::Quantized8Asymm(1.4f, 110))
  263. .set_rng(0, &rng)
  264. .set_rng(1, &rng)
  265. .set_rng(2, &rng);
  266. for (auto&& arg : args) {
  267. checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}});
  268. }
  269. }
  270. void checker_conv_bias_quint8x8x32(std::vector<conv_bias::TestArg> args,
  271. Handle* handle, const char* algo_name) {
  272. Checker<ConvBias> checker(handle);
  273. checker.set_before_exec_callback(
  274. conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name));
  275. NormalRNG rng(128.f);
  276. checker.set_rng(0, &rng).set_rng(1, &rng);
  277. checker.set_dtype(0, dtype::Quantized8Asymm(1.2f, (uint8_t)127))
  278. .set_dtype(1, dtype::Quantized8Asymm(1.3f, (uint8_t)129))
  279. .set_dtype(2, dtype::QuantizedS32(1.2 * 1.3))
  280. .set_dtype(4, {});
  281. for (auto&& arg : args) {
  282. checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}});
  283. }
  284. }
  285. void checker_conv_bias_int8x8x32_multi(std::vector<conv_bias::TestArg> args,
  286. Handle* handle, const char* algo_name) {
  287. Checker<ConvBias> checker(handle);
  288. checker.set_before_exec_callback(
  289. conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name));
  290. checker.set_dtype(0, dtype::Int8());
  291. checker.set_dtype(1, dtype::Int8());
  292. checker.set_dtype(2, dtype::Int32());
  293. checker.set_dtype(4, dtype::Int32());
  294. for (auto&& arg : args) {
  295. checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}});
  296. }
  297. }
/**********************************F32 direct************************/
//! Each case exercises one fp32 direct-conv algorithm, selected by name,
//! on the multi-threaded ARM handle over a sweep of kernel sizes.
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_LARGE_GROUP) {
    check_conv_bias(
            get_conv_bias_args({1, 2, 3, 4, 5, 6, 7}, 1, false, false, false),
            handle(), "F32DIRECT_LARGE_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_SMALL_GROUP) {
    check_conv_bias(
            get_conv_bias_args({1, 2, 3, 4, 5, 6, 7}, 1, false, false, false),
            handle(), "F32DIRECT_SMALL_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR1_LARGE_GROUP) {
    check_conv_bias(get_conv_bias_args({2, 3, 5, 7}, 1, false, false, false),
                    handle(), "F32STRD1_LARGE_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR1_SMALL_GROUP) {
    check_conv_bias(get_conv_bias_args({2, 3, 5, 7}, 1, false, false, false),
                    handle(), "F32STRD1_SMALL_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR2_LARGE_GROUP) {
    check_conv_bias(get_conv_bias_args({2, 3, 5, 7}, 2, false, false, false),
                    handle(), "F32STRD2_LARGE_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR2_SMALL_GROUP) {
    check_conv_bias(get_conv_bias_args({2, 3, 5, 7}, 2, false, false, false),
                    handle(), "F32STRD2_SMALL_GROUP");
}
//! hybrid NCHW-input / NCHW44-output layout (is_input_nchw = true)
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_NCHW_NCHW44_F32) {
    check_conv_bias(get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, false,
                                              false, true),
                    handle(), "F32_CONV_NCHW_NCHW44");
}
/**********************************F16 direct************************/
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
//! fp16 direct/stride-1 kernels; each test passes a 0.03 tolerance to
//! checker_conv_bias_f16 (half precision needs a looser epsilon).
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16_LARGE_GROUP) {
    NormalRNG rng(1);
    checker_conv_bias_f16(
            get_conv_bias_args({1, 2, 3, 4, 5, 6, 7}, 1, false, false, false),
            handle(), rng, "F16DIRECT_LARGE_GROUP", 0.03);
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16_SMALL_GROUP) {
    NormalRNG rng(1);
    checker_conv_bias_f16(
            get_conv_bias_args({1, 2, 3, 4, 5, 6, 7}, 1, false, false, false),
            handle(), rng, "F16DIRECT_SMALL_GROUP", 0.03);
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16_STR1_LARGE_GROUP) {
    NormalRNG rng(1);
    checker_conv_bias_f16(get_conv_bias_args({2, 3, 5}, 1, false, false, false),
                          handle(), rng, "F16STRD1_LARGE_GROUP", 0.03);
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16_STR1_SMALL_GROUP) {
    NormalRNG rng(1);
    checker_conv_bias_f16(get_conv_bias_args({2, 3, 5}, 1, false, false, false),
                          handle(), rng, "F16STRD1_SMALL_GROUP", 0.03);
}
#endif
/**********************************algo 8816 direct************************/
//! int8 x int8 -> int16 direct and stride-2 kernels over kernels {2, 3, 5};
//! the arg generator is called with its two trailing flags set to true
//! (bias/nonlinearity disabled, mirroring get_int8_quint8_conv_bias_args).
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT16_DIRECT_LARGE_GROUP) {
    checker_conv_bias_int8x8x16(
            get_conv_bias_args({2, 3, 5}, 1, false, true, true), handle(),
            "I8816DIRECT_LARGE_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT16_DIRECT_SMALL_GROUP) {
    checker_conv_bias_int8x8x16(
            get_conv_bias_args({2, 3, 5}, 1, false, true, true), handle(),
            "I8816DIRECT_SMALL_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT16_STRIDE2_LARGE_GROUP) {
    checker_conv_bias_int8x8x16(
            get_conv_bias_args({2, 3, 5}, 2, false, true, true), handle(),
            "I8816STRD2_LARGE_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT16_STRIDE2_SMALL_GROUP) {
    checker_conv_bias_int8x8x16(
            get_conv_bias_args({2, 3, 5}, 2, false, true, true), handle(),
            "I8816STRD2_SMALL_GROUP");
}
/**********************************algo 8-8-32 direct************************/
//! int8 x int8 -> int32 direct kernels (strides 1/2), plus NCHW44
//! channel-wise variants driven by the channel-wise arg generator.
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT32_STRIDE1_LARGE_GROUP) {
    checker_conv_bias_int8x8x32_multi(
            get_conv_bias_args({2, 3, 5, 7}, 1, false, true, true), handle(),
            "S8STRD1_LARGE_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT32_STRIDE1_SMALL_GROUP) {
    checker_conv_bias_int8x8x32_multi(
            get_conv_bias_args({2, 3, 5, 7}, 1, false, true, true), handle(),
            "S8STRD1_SMALL_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT32_STRIDE2_LARGE_GROUP) {
    checker_conv_bias_int8x8x32_multi(
            get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(),
            "S8STRD2_LARGE_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT32_STRIDE2_SMALL_GROUP) {
    checker_conv_bias_int8x8x32_multi(
            get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(),
            "S8STRD2_SMALL_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS,
       CONV_BIAS_INT8_INT8_INT32_CHANNEL_WISE_DIRECT1_NCHW44) {
    checker_conv_bias_int8x8x32_multi(
            get_int8_quint8_nchw44_channel_wise_args({2, 3, 5}, 1, false, true),
            handle(), "S8_CHAN_WISE_STRD1_NCHW44");
}
TEST_F(ARM_COMMON_MULTI_THREADS,
       CONV_BIAS_INT8_INT8_INT32_CHANNEL_WISE_DIRECT2_NCHW44) {
    checker_conv_bias_int8x8x32_multi(
            get_int8_quint8_nchw44_channel_wise_args({2, 3, 5}, 2, false, true),
            handle(), "S8_CHAN_WISE_STRD2_NCHW44");
}
/********************************qint8 direct******************************/
//! quantized int8 direct kernels: plain NCHW strides 1/2, NCHW44 direct,
//! NCHW44 channel-wise, and the NCHW->NCHW44 hybrid layout.
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_LARGE_GROUP) {
    checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args(
                                        {2, 3, 5, 7}, 1, false, false, false),
                                handle(), "S8STRD1_LARGE_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_SMALL_GROUP) {
    checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args(
                                        {2, 3, 5, 7}, 1, false, false, false),
                                handle(), "S8STRD1_SMALL_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_LARGE_GROUP) {
    checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args(
                                        {2, 3, 5, 7}, 2, false, false, false),
                                handle(), "S8STRD2_LARGE_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_SMALL_GROUP) {
    checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args(
                                        {2, 3, 5, 7}, 2, false, false, false),
                                handle(), "S8STRD2_SMALL_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_NCHW44) {
    checker_conv_bias_qint8x8x8(
            get_nchw44_conv_bias_args({2, 3, 5, 7}, 1, false, false, false),
            handle(), "S8_NCHW44_DIRECT_STRD1");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_NCHW44) {
    checker_conv_bias_qint8x8x8(
            get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, false, false),
            handle(), "S8_NCHW44_DIRECT_STRD2");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QS8_CHANNEL_WISE_DIRECT1_NCHW44) {
    checker_conv_bias_qint8x8x8(get_int8_quint8_nchw44_channel_wise_args(
                                        {2, 3, 5}, 1, false, false),
                                handle(), "S8_CHAN_WISE_STRD1_NCHW44");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QS8_CHANNEL_WISE_DIRECT2_NCHW44) {
    checker_conv_bias_qint8x8x8(get_int8_quint8_nchw44_channel_wise_args(
                                        {2, 3, 5}, 2, false, false),
                                handle(), "S8_CHAN_WISE_STRD2_NCHW44");
}
//! hybrid NCHW-input layout (is_input_nchw = true); kernel 2 excluded here
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_NCHW_NCHW44) {
    checker_conv_bias_qint8x8x8(
            get_nchw44_conv_bias_args({3, 5, 7}, 2, false, false, false, true),
            handle(), "S8_CONV_NCHW_NCHW44");
}
/*****************************quint8 direct****************************/
//! asymmetric-quantized uint8 direct kernels, strides 1 and 2.
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE1_LARGE_GROUP) {
    checker_conv_bias_quint8x8x8(get_int8_quint8_conv_bias_args(
                                         {2, 3, 5, 7}, 1, false, false, false),
                                 handle(), "QU8STRD1_LARGE_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE1_SMALL_GROUP) {
    checker_conv_bias_quint8x8x8(get_int8_quint8_conv_bias_args(
                                         {2, 3, 5, 7}, 1, false, false, false),
                                 handle(), "QU8STRD1_SMALL_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE2_LARGE_GROUP) {
    checker_conv_bias_quint8x8x8(get_int8_quint8_conv_bias_args(
                                         {2, 3, 5, 7}, 2, false, false, false),
                                 handle(), "QU8STRD2_LARGE_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE2_SMALL_GROUP) {
    checker_conv_bias_quint8x8x8(get_int8_quint8_conv_bias_args(
                                         {2, 3, 5, 7}, 2, false, false, false),
                                 handle(), "QU8STRD2_SMALL_GROUP");
}
/****************************dot qint8 direct*************************/
#if __ARM_FEATURE_DOTPROD
//! Dot-product accelerated kernels; compiled only when the toolchain/target
//! advertises the ARM dot-product extension.
TEST_F(ARM_COMMON_MULTI_THREADS,
       CONV_BIAS_INT8_STRIDE1_WITHDOTPROD_LARGE_GROUP) {
    checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args(
                                        {2, 3, 5, 7}, 1, false, false, false),
                                handle(), "ARMDOTS8STRD1_LARGE_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS,
       CONV_BIAS_INT8_STRIDE1_WITHDOTPROD_SMALL_GROUP) {
    checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args(
                                        {2, 3, 5, 7}, 1, false, false, false),
                                handle(), "ARMDOTS8STRD1_SMALL_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS,
       CONV_BIAS_INT8_STRIDE2_WITHDOTPROD_LARGE_GROUP) {
    checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args(
                                        {2, 3, 5, 7}, 2, false, false, false),
                                handle(), "ARMDOTS8STRD2_LARGE_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS,
       CONV_BIAS_INT8_STRIDE2_WITHDOTPROD_SMALL_GROUP) {
    checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args(
                                        {2, 3, 5, 7}, 2, false, false, false),
                                handle(), "ARMDOTS8STRD2_SMALL_GROUP");
}
/****************************dot 8-8-32 direct*************************/
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_I8832STRD1_WITHDOT_LARGE_GROUP) {
    checker_conv_bias_qint8x8x32(
            get_conv_bias_args({2, 3, 5, 7}, 1, false, true, true), handle(),
            "ARMDOTS8STRD1_LARGE_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_I8832STRD1_WITHDOT_SMALL_GROUP) {
    checker_conv_bias_qint8x8x32(
            get_conv_bias_args({2, 3, 5, 7}, 1, false, true, true), handle(),
            "ARMDOTS8STRD1_SMALL_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_I8832STRD2_WITHDOT_LARGE_GROUP) {
    checker_conv_bias_qint8x8x32(
            get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(),
            "ARMDOTS8STRD2_LARGE_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_I8832STRD2_WITHDOT_SMALL_GROUP) {
    checker_conv_bias_qint8x8x32(
            get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(),
            "ARMDOTS8STRD2_SMALL_GROUP");
}
/******************************dot quint8*****************************/
TEST_F(ARM_COMMON_MULTI_THREADS,
       CONV_BIAS_QUINT8_STRIDE1_WITHDOTPROD_LARGE_GROUP) {
    checker_conv_bias_quint8x8x8(get_int8_quint8_conv_bias_args(
                                         {2, 3, 5, 7}, 1, false, false, false),
                                 handle(), "ARMDOTU8STRD1_LARGE_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS,
       CONV_BIAS_QUINT8_STRIDE1_WITHDOTPROD_SMALL_GROUP) {
    checker_conv_bias_quint8x8x8(get_int8_quint8_conv_bias_args(
                                         {2, 3, 5, 7}, 1, false, false, false),
                                 handle(), "ARMDOTU8STRD1_SMALL_GROUP");
}
//! NOTE(review): stride-2 quint8 dot cases skip kernel 3, unlike stride 1;
//! presumably unsupported by that kernel — confirm before extending the set.
TEST_F(ARM_COMMON_MULTI_THREADS,
       CONV_BIAS_QUINT8_STRIDE2_WITHDOTPROD_LARGE_GROUP) {
    checker_conv_bias_quint8x8x8(
            get_int8_quint8_conv_bias_args({2, 5, 7}, 2, false, false, false),
            handle(), "ARMDOTU8STRD2_LARGE_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS,
       CONV_BIAS_QUINT8_STRIDE2_WITHDOTPROD_SMALL_GROUP) {
    checker_conv_bias_quint8x8x8(
            get_int8_quint8_conv_bias_args({2, 5, 7}, 2, false, false, false),
            handle(), "ARMDOTU8STRD2_SMALL_GROUP");
}
/******************************dot quint8x8x32***********************/
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_QUINT8_DIRECT_STRIDE1_LARGE_GROUP) {
    checker_conv_bias_quint8x8x32(
            get_conv_bias_args({2, 3, 5, 7}, 1, false, true, true), handle(),
            "ARMDOTU8STRD1_LARGE_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_QUINT8_DIRECT_STRIDE1_SMALL_GROUP) {
    checker_conv_bias_quint8x8x32(
            get_conv_bias_args({2, 3, 5, 7}, 1, false, true, true), handle(),
            "ARMDOTU8STRD1_SMALL_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_QUINT8_DIRECT_STRIDE2_LARGE_GROUP) {
    checker_conv_bias_quint8x8x32(
            get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(),
            "ARMDOTU8STRD2_LARGE_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_QUINT8_DIRECT_STRIDE2_SMALL_GROUP) {
    checker_conv_bias_quint8x8x32(
            get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(),
            "ARMDOTU8STRD2_SMALL_GROUP");
}
#endif
//! Winograd variant tests. NOTE(review): the "a:b:c" string passed to
//! check_winograd appears to select pack size / output block size / a block
//! parameter — confirm the exact encoding in check_winograd's implementation.
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F23_4) {
    using namespace conv_bias;
    std::vector<TestArg> args = get_winograd_mk_packed_args();
    Checker<ConvBiasForward> checker(handle());
    check_winograd("4:2:32", checker, args, param::MatrixMul::Format::MK4);
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F63) {
    using namespace conv_bias;
    std::vector<TestArg> args = get_winograd_args(3);
    Checker<ConvBiasForward> checker(handle());
    check_winograd("1:6:32", checker, args);
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F63_4) {
    using namespace conv_bias;
    std::vector<TestArg> args = get_winograd_mk_packed_args();
    Checker<ConvBiasForward> checker(handle());
    check_winograd("4:6:32", checker, args, param::MatrixMul::Format::MK4);
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F54) {
    using namespace conv_bias;
    std::vector<TestArg> args = get_winograd_args(4);
    Checker<ConvBiasForward> checker(handle());
    check_winograd("1:5:32", checker, args);
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F45) {
    using namespace conv_bias;
    std::vector<TestArg> args = get_winograd_args(5);
    Checker<ConvBiasForward> checker(handle());
    check_winograd("1:4:32", checker, args);
}
  600. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD) {
  601. using namespace conv_bias;
  602. std::vector<TestArg> args = get_winograd_args(3);
  603. Checker<ConvBiasForward> checker(handle());
  604. auto extra_impl = [](const TensorNDArray& tensors, uint32_t m,
  605. param::ConvBias param, Handle* handle) {
  606. megdnn_assert(param.format == param::ConvBias::Format::NCHW);
  607. auto winograd_preprocess_opr =
  608. handle->create_operator<WinogradFilterPreprocess>();
  609. winograd_preprocess_opr->param().output_block_size = m;
  610. TensorLayout filter_transform_layout;
  611. winograd_preprocess_opr->deduce_layout(tensors[1].layout,
  612. filter_transform_layout);
  613. size_t winograd_preprocess_workspace_in_bytes =
  614. winograd_preprocess_opr->get_workspace_in_bytes(
  615. tensors[1].layout, filter_transform_layout);
  616. auto conv_bias_opr = handle->create_operator<ConvBias>();
  617. conv_bias_opr->param() = param;
  618. conv_bias_opr->param().format = param::ConvBias::Format::NCHW_WINOGRAD;
  619. conv_bias_opr->param().output_block_size = m;
  620. size_t conv_bias_workspace_in_bytes =
  621. conv_bias_opr->get_workspace_in_bytes(
  622. tensors[0].layout, filter_transform_layout,
  623. tensors[2].layout, tensors[3].layout,
  624. tensors[4].layout);
  625. WorkspaceBundle wb(nullptr, {filter_transform_layout.span().dist_byte(),
  626. conv_bias_workspace_in_bytes,
  627. winograd_preprocess_workspace_in_bytes});
  628. wb.set(malloc(wb.total_size_in_bytes()));
  629. TensorND filter_transform_tensor(wb.get(0),
  630. std::move(filter_transform_layout));
  631. winograd_preprocess_opr->exec(tensors[1], filter_transform_tensor,
  632. wb.get_workspace(2));
  633. conv_bias_opr->exec(tensors[0], filter_transform_tensor, tensors[2],
  634. tensors[3], tensors[4], wb.get_workspace(1));
  635. free(wb.ptr());
  636. };
  637. auto run = [&checker, &extra_impl](
  638. Handle* handle, const std::vector<TestArg>& args,
  639. const std::vector<size_t>& out_size, DType A_dtype,
  640. DType B_dtype, DType C_dtype, DType D_dtype,
  641. const float eps) {
  642. for (auto&& arg : args) {
  643. for (uint32_t m : out_size) {
  644. checker.set_extra_opr_impl(std::bind(extra_impl,
  645. std::placeholders::_1, m,
  646. arg.param, handle));
  647. checker.set_dtype(0, A_dtype)
  648. .set_dtype(1, B_dtype)
  649. .set_dtype(2, C_dtype)
  650. .set_dtype(4, D_dtype)
  651. .set_epsilon(eps)
  652. .set_param(arg.param)
  653. .execs({arg.src, arg.filter, arg.bias, {}, {}});
  654. }
  655. }
  656. };
  657. run(handle(), args, {6}, dtype::Float32(), dtype::Float32(),
  658. dtype::Float32(), dtype::Float32(), 1e-3f);
  659. #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  660. Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00);
  661. checker.set_rng(0, rng).set_rng(1, rng).set_rng(2, rng);
  662. run(handle(), args, {6}, dtype::Float16(), dtype::Float16(),
  663. dtype::Float16(), dtype::Float16(), 0.35f);
  664. #endif
  665. }
  666. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_MK_PACKED_F32_1) {
  667. using namespace conv_bias;
  668. Checker<ConvBiasForward> checker(handle());
  669. auto run = [&checker](Handle* handle, const std::vector<TestArg>& args,
  670. const std::vector<size_t>& out_size, DType A_dtype,
  671. DType B_dtype, DType C_dtype, DType D_dtype,
  672. param::MatrixMul::Format format, float eps) {
  673. for (auto&& arg : args) {
  674. for (uint32_t m : out_size) {
  675. checker.set_extra_opr_impl(std::bind(
  676. winograd_algo_extra_impl, std::placeholders::_1, m,
  677. arg.param, handle, format));
  678. checker.set_dtype(0, A_dtype)
  679. .set_dtype(1, B_dtype)
  680. .set_dtype(2, C_dtype)
  681. .set_dtype(4, D_dtype)
  682. .set_epsilon(eps)
  683. .set_param(arg.param)
  684. .execs({arg.src, arg.filter, arg.bias, {}, {}});
  685. }
  686. }
  687. };
  688. std::vector<TestArg> args = get_winograd_mk_packed_args(8);
  689. std::vector<TestArg> args_first_half(args.begin(),
  690. args.begin() + args.size() / 2);
  691. run(handle(), args_first_half, {2, 6}, dtype::Float32{}, dtype::Float32{},
  692. dtype::Float32{}, dtype::Float32{}, param::MatrixMul::Format::MK4,
  693. 1e-3f);
  694. }
  695. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_MK_PACKED_F32_2) {
  696. using namespace conv_bias;
  697. Checker<ConvBiasForward> checker(handle());
  698. auto run = [&checker](Handle* handle, const std::vector<TestArg>& args,
  699. const std::vector<size_t>& out_size, DType A_dtype,
  700. DType B_dtype, DType C_dtype, DType D_dtype,
  701. param::MatrixMul::Format format, float eps) {
  702. for (auto&& arg : args) {
  703. for (uint32_t m : out_size) {
  704. checker.set_extra_opr_impl(std::bind(
  705. winograd_algo_extra_impl, std::placeholders::_1, m,
  706. arg.param, handle, format));
  707. checker.set_dtype(0, A_dtype)
  708. .set_dtype(1, B_dtype)
  709. .set_dtype(2, C_dtype)
  710. .set_dtype(4, D_dtype)
  711. .set_epsilon(eps)
  712. .set_param(arg.param)
  713. .execs({arg.src, arg.filter, arg.bias, {}, {}});
  714. }
  715. }
  716. };
  717. std::vector<TestArg> args = get_winograd_mk_packed_args(8);
  718. std::vector<TestArg> args_second_half(args.begin() + args.size() / 2,
  719. args.end());
  720. run(handle(), args_second_half, {2, 6}, dtype::Float32{}, dtype::Float32{},
  721. dtype::Float32{}, dtype::Float32{}, param::MatrixMul::Format::MK4,
  722. 1e-3f);
  723. }
  724. #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  725. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_MK_PACKED_F16) {
  726. using namespace conv_bias;
  727. Checker<ConvBiasForward> checker(handle());
  728. auto run = [&checker](Handle* handle, const std::vector<TestArg>& args,
  729. const std::vector<size_t>& out_size, DType A_dtype,
  730. DType B_dtype, DType C_dtype, DType D_dtype,
  731. param::MatrixMul::Format format, float eps) {
  732. for (auto&& arg : args) {
  733. for (uint32_t m : out_size) {
  734. checker.set_extra_opr_impl(std::bind(
  735. winograd_algo_extra_impl, std::placeholders::_1, m,
  736. arg.param, handle, format));
  737. checker.set_dtype(0, A_dtype)
  738. .set_dtype(1, B_dtype)
  739. .set_dtype(2, C_dtype)
  740. .set_dtype(4, D_dtype)
  741. .set_epsilon(eps)
  742. .set_param(arg.param)
  743. .execs({arg.src, arg.filter, arg.bias, {}, {}});
  744. }
  745. }
  746. };
  747. std::vector<TestArg> args = get_winograd_mk_packed_args(8);
  748. Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00);
  749. checker.set_rng(0, rng).set_rng(1, rng).set_rng(2, rng);
  750. run(handle(), args, {2}, dtype::Float16{}, dtype::Float16{},
  751. dtype::Float16{}, dtype::Float16{}, param::MatrixMul::Format::MK8,
  752. 0.25);
  753. }
  754. #endif
  755. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_MK_PACKED_INT8) {
  756. using namespace conv_bias;
  757. Checker<ConvBiasForward> checker(handle());
  758. auto run = [&checker](Handle* handle, const std::vector<TestArg>& args,
  759. const std::vector<size_t>& out_size, DType A_dtype,
  760. DType B_dtype, DType C_dtype, DType D_dtype,
  761. param::MatrixMul::Format format, float eps) {
  762. for (auto&& arg : args) {
  763. for (uint32_t m : out_size) {
  764. checker.set_extra_opr_impl(std::bind(
  765. winograd_algo_extra_impl, std::placeholders::_1, m,
  766. arg.param, handle, format));
  767. checker.set_dtype(0, A_dtype)
  768. .set_dtype(1, B_dtype)
  769. .set_dtype(2, C_dtype)
  770. .set_dtype(4, D_dtype)
  771. .set_epsilon(eps)
  772. .set_param(arg.param)
  773. .execs({arg.src, arg.filter, arg.bias, {}, {}});
  774. }
  775. }
  776. };
  777. #if MEGDNN_AARCH64
  778. const char* matmul_name = "AARCH64_INT16X16X32_MK8_8X8";
  779. #else
  780. const char* matmul_name = "ARMV7_INT16X16X32_MK8_4X8";
  781. #endif
  782. checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>(
  783. ssprintf("WINOGRAD:%s:8:2:32", matmul_name).c_str()));
  784. std::vector<TestArg> args = get_winograd_mk_packed_args(8);
  785. std::vector<TestArg> quantized_args =
  786. get_quantized_winograd_mk_packed_args(8);
  787. UniformIntRNG int_rng{-50, 50};
  788. checker.set_rng(0, &int_rng).set_rng(1, &int_rng).set_rng(2, &int_rng);
  789. run(handle(), quantized_args, {2}, dtype::QuantizedS8(2.5f),
  790. dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f),
  791. dtype::QuantizedS8(60.25f), param::MatrixMul::Format::MK8, 1e-3);
  792. }
  793. #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  794. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_F23) {
  795. using namespace conv_bias;
  796. std::vector<TestArg> args = get_winograd_mk_packed_args();
  797. Checker<ConvBiasForward> checker(handle());
  798. check_winograd_fp16("1:2:32", checker, args, NULL, 0.08);
  799. }
  800. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_F45_1) {
  801. using namespace conv_bias;
  802. std::vector<TestArg> args = get_winograd_args(5);
  803. std::vector<TestArg> args_head_half(args.begin(),
  804. args.begin() + args.size() / 2);
  805. Checker<ConvBiasForward> checker(handle());
  806. //! fp16 range -1.0 ~ 1.0
  807. Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00);
  808. check_winograd_fp16("1:4:32", checker, args_head_half, rng, 0.25);
  809. }
  810. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_F45_2) {
  811. using namespace conv_bias;
  812. std::vector<TestArg> args = get_winograd_args(5);
  813. std::vector<TestArg> args_back_half(args.begin() + args.size() / 2,
  814. args.end());
  815. Checker<ConvBiasForward> checker(handle());
  816. //! fp16 range -1.0 ~ 1.0
  817. Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00);
  818. check_winograd_fp16("1:4:32", checker, args_back_half, rng, 0.25);
  819. }
  820. //! FIXME: This test may be failed if run `ARM_COMMON.CONV_BIAS_WINOGRAD*`, but
  821. //! it will pass when run single testcase
  822. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_F63) {
  823. using namespace conv_bias;
  824. std::vector<TestArg> args = get_winograd_args(3);
  825. Checker<ConvBiasForward> checker(handle());
  826. //! fp16 range -1.0 ~ 1.0
  827. Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00);
  828. check_winograd_fp16("1:6:32", checker, args, rng, 0.3);
  829. }
  830. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_8x8_1) {
  831. using namespace conv_bias;
  832. std::vector<TestArg> args = get_winograd_mk_packed_args(8);
  833. std::vector<TestArg> args_head_half(args.begin(),
  834. args.begin() + args.size() / 2);
  835. Checker<ConvBiasForward> checker(handle());
  836. Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00);
  837. check_winograd_fp16("8:2:32", checker, args_head_half, rng, 0.25,
  838. param::MatrixMul::Format::MK8);
  839. }
  840. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_8x8_2) {
  841. using namespace conv_bias;
  842. std::vector<TestArg> args = get_winograd_mk_packed_args(8);
  843. std::vector<TestArg> args_back_half(args.begin() + args.size() / 2,
  844. args.end());
  845. Checker<ConvBiasForward> checker(handle());
  846. Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00);
  847. check_winograd_fp16("8:2:32", checker, args_back_half, rng, 0.25,
  848. param::MatrixMul::Format::MK8);
  849. }
  850. #endif
  851. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_INT8_8X8) {
  852. using namespace conv_bias;
  853. std::vector<TestArg> args = get_quantized_winograd_mk_packed_args(8);
  854. Checker<ConvBiasForward> checker(handle());
  855. UniformIntRNG rng{-50, 50};
  856. checker.set_dtype(0, dtype::QuantizedS8(2.5f))
  857. .set_dtype(1, dtype::QuantizedS8(2.5f))
  858. .set_dtype(2, dtype::QuantizedS32(6.25f))
  859. .set_dtype(4, dtype::QuantizedS8(60.25f))
  860. .set_rng(0, &rng)
  861. .set_rng(1, &rng)
  862. .set_rng(2, &rng);
  863. check_winograd("8:2:32", checker, args, param::MatrixMul::Format::MK8);
  864. }
  865. void checker_conv_bias(std::vector<conv_bias::TestArg> args, Handle* handle,
  866. RNG* rng, float epsilon, DType type0, DType type1,
  867. DType type2, DType type3, const char* algo_name) {
  868. using namespace conv_bias;
  869. Checker<ConvBias> checker(handle);
  870. checker.set_before_exec_callback(
  871. conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name));
  872. checker.set_dtype(0, type0);
  873. checker.set_dtype(1, type1);
  874. checker.set_dtype(2, type2);
  875. checker.set_dtype(4, type3);
  876. checker.set_epsilon(epsilon);
  877. if (NULL != rng) {
  878. checker.set_rng(0, rng).set_rng(1, rng).set_rng(2, rng).set_rng(3, rng);
  879. }
  880. for (auto&& arg : args) {
  881. checker.set_param(arg.param).execs(
  882. {arg.src, arg.filter, arg.bias, {}, {}});
  883. }
  884. }
  885. // clang-format off
//! fp32 im2col+matmul conv_bias, stride 2, kernel sizes 1-7; each cb()
//! expansion checks one arch-specific matmul backend by algorithm name.
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COL_FP32_STRIDE2) {
#define cb(name)                                                               \
    check_conv_bias(                                                           \
            get_conv_bias_args({1, 2, 3, 4, 5, 6, 7}, 2, false, false, false), \
            handle(), name);
#if MEGDNN_AARCH64
    cb("IM2COLMATMUL:AARCH64_F32K8X12X1")
    cb("IM2COLMATMUL:AARCH64_F32K4X16X1")
    cb("IM2COLMATMUL:FB_F32_K8X12X1")
#elif MEGDNN_ARMV7
    cb("IM2COLMATMUL:ARMV7_F32")
#endif
#undef cb
}
//! fp32 im2col+matmul conv_bias, stride 1, kernel sizes 2-7 (1x1 stride-1 is
//! covered by the CONV1x1 tests below).
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COL_FP32_STRIDE1) {
#define cb(name)                                                            \
    check_conv_bias(                                                        \
            get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, false, false), \
            handle(), name);
#if MEGDNN_AARCH64
    cb("IM2COLMATMUL:AARCH64_F32K8X12X1")
    cb("IM2COLMATMUL:AARCH64_F32K4X16X1")
    cb("IM2COLMATMUL:FB_F32_K8X12X1")
#elif MEGDNN_ARMV7
    cb("IM2COLMATMUL:ARMV7_F32")
    cb("IM2COLMATMUL:FB_F32_K8X12X1")
#endif
#undef cb
}
//! symmetric-quantized (QuantizedS8 in/out, QS32 bias) im2col+matmul tests:
//! stride-1 kernels 2-7 plus the stride-2 1x1 case.
//! `epsilon` is referenced inside cb(), so it is declared before the first
//! expansion; ARMv7 relaxes it to 1 for the int8 kernel.
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM) {
    UniformIntRNG rng{-50, 50};
#define cb(name)                                                              \
    checker_conv_bias(get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, false, \
                                         false, true, true),                  \
                      handle(), &rng, epsilon, dtype::QuantizedS8(2.5f),      \
                      dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f),   \
                      dtype::QuantizedS8(60.25f), name);                      \
    checker_conv_bias(                                                        \
            get_conv_bias_args({1}, 2, false, false, false, true, true),      \
            handle(), &rng, epsilon, dtype::QuantizedS8(2.5f),                \
            dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f),             \
            dtype::QuantizedS8(60.25f), name);
    float epsilon = 0.001;
#if MEGDNN_AARCH64
#if __ARM_FEATURE_DOTPROD
    cb("IM2COLMATMUL:AARCH64_INT8X8X32_K8X12X4_DOTPROD");
#else
    cb("IM2COLMATMUL:AARCH64_INT8X8X32_K8X8X8");
    cb("IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16");
#endif
#elif MEGDNN_ARMV7
    epsilon = 1;
    cb("IM2COLMATMUL:ARMV7_INT8X8X32_K4X8X8");
#endif
#undef cb
}
  942. // clang-format on
  943. #if MEGDNN_AARCH64 || MEGDNN_ARMV7
//! asymmetric-quantized (Quantized8Asymm with distinct zero points) im2col
//! tests; bias scale is the product of the input scales (1.2 * 1.3).
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUANTIZEDASYM) {
    NormalRNG rng(128.f);
#define cb(name)                                                              \
    checker_conv_bias(get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, false, \
                                         false, true, true),                  \
                      handle(), &rng, epsilon,                                \
                      dtype::Quantized8Asymm(1.2f, (uint8_t)125),             \
                      dtype::Quantized8Asymm(1.3f, (uint8_t)129),             \
                      dtype::QuantizedS32(1.2 * 1.3),                         \
                      dtype::Quantized8Asymm(50.3f, (uint8_t)120), name);     \
    checker_conv_bias(                                                        \
            get_conv_bias_args({1}, 2, false, false, false, true, true),      \
            handle(), &rng, epsilon,                                          \
            dtype::Quantized8Asymm(1.2f, (uint8_t)125),                       \
            dtype::Quantized8Asymm(1.3f, (uint8_t)129),                       \
            dtype::QuantizedS32(1.2 * 1.3),                                   \
            dtype::Quantized8Asymm(50.3f, (uint8_t)120), name);
    float epsilon = 0.001;
#if MEGDNN_AARCH64
#if __ARM_FEATURE_DOTPROD
    cb("IM2COLMATMUL:AARCH64_QUINT8_K8X8X4_DOTPROD");
#else
    cb("IM2COLMATMUL:AARCH64_QUINT8_K8X8X8");
#endif
#elif MEGDNN_ARMV7
    epsilon = 1;
    cb("IM2COLMATMUL:ARMV7_QUINT8_K4X8X8");
#endif
#undef cb
}
  974. #endif
  975. #if MEGDNN_AARCH64 || MEGDNN_ARMV7
//! quint8 inputs with raw QS32 output (dst dtype `{}` lets the checker deduce
//! it), i.e. no requantization stage.
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUINT8x8x32) {
    UniformIntRNG rng{-50, 50};
    float epsilon = 0.001;
#define cb(name)                                                          \
    checker_conv_bias(                                                    \
            get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, true, true), \
            handle(), &rng, epsilon,                                      \
            dtype::Quantized8Asymm(1.2f, (uint8_t)125),                   \
            dtype::Quantized8Asymm(1.3f, (uint8_t)129),                   \
            dtype::QuantizedS32(1.2 * 1.3), {}, name);                    \
    checker_conv_bias(get_conv_bias_args({1}, 2, false, true, true),      \
                      handle(), &rng, epsilon,                            \
                      dtype::Quantized8Asymm(1.2f, (uint8_t)125),         \
                      dtype::Quantized8Asymm(1.3f, (uint8_t)129),         \
                      dtype::QuantizedS32(1.2 * 1.3), {}, name);
#if MEGDNN_AARCH64
#if __ARM_FEATURE_DOTPROD
    cb("IM2COLMATMUL:AARCH64_QUINT8_K8X8X4_DOTPROD");
#else
    cb("IM2COLMATMUL:AARCH64_QUINT8_K8X8X8");
#endif
#elif MEGDNN_ARMV7
#if __ARM_FEATURE_DOTPROD
    cb("IM2COLMATMUL:AARCH32_QUINT8_K4X8X4");
#endif
    cb("IM2COLMATMUL:ARMV7_QUINT8_K4X8X8");
#endif
#undef cb
}
//! int8 src/filter with int16 accumulation/output im2col tests.
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COLMATMUL_INT8x8x16) {
    UniformIntRNG rng{-50, 50};
    float epsilon = 0.001;
#define cb(name)                                                          \
    checker_conv_bias(                                                    \
            get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, true, true), \
            handle(), &rng, epsilon, dtype::Int8{}, dtype::Int8{},        \
            dtype::Int16{}, dtype::Int16{}, name);                        \
    checker_conv_bias(get_conv_bias_args({1}, 2, false, true, true),      \
                      handle(), &rng, epsilon, dtype::Int8{},             \
                      dtype::Int8{}, dtype::Int16{}, dtype::Int16{}, name);
#if MEGDNN_AARCH64
    cb("IM2COLMATMUL:AARCH64_INT8X8X16_K8X8X8");
    cb("IM2COLMATMUL:AARCH64_INT8X8X16_K4X4X16");
    cb("IM2COLMATMUL:ARM_COMMON_INT8X8X16");
#elif MEGDNN_ARMV7
    cb("IM2COLMATMUL:ARM_COMMON_INT8X8X16");
    cb("IM2COLMATMUL:ARMV7_INT8X8X16_K4X8X8");
    cb("IM2COLMATMUL:ARMV7_INT8X8X16_K4X2X16");
#endif
#undef cb
}
  1027. #endif
  1028. #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  1029. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_FP16) {
  1030. using namespace conv_bias;
  1031. param::ConvBias cur_param;
  1032. std::vector<conv_bias::TestArg> args =
  1033. get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, false, false);
  1034. std::vector<conv_bias::TestArg> args1 =
  1035. get_conv_bias_args({1}, 2, false, false, false);
  1036. args.insert(args.begin(), args1.begin(), args1.end());
  1037. NormalRNG rng(1);
  1038. #define cb(name) \
  1039. checker_conv_bias(args, handle(), &rng, 0.03, dtype::Float16{}, \
  1040. dtype::Float16{}, dtype::Float16{}, dtype::Float16{}, \
  1041. name);
  1042. #if MEGDNN_AARCH64
  1043. cb("IM2COLMATMUL:AARCH64_F16_K8X24X1");
  1044. #elif MEGDNN_ARMV7
  1045. cb("IM2COLMATMUL:AARCH32_F16_K4X16X1");
  1046. #endif
  1047. #undef cb
  1048. }
  1049. #endif
  1050. void checker_conv_bias_mul_int8x8x32(std::vector<conv_bias::TestArg> args,
  1051. Handle* handle, const char* algo_name) {
  1052. using namespace conv_bias;
  1053. Checker<ConvBias> checker(handle);
  1054. checker.set_before_exec_callback(
  1055. conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name));
  1056. checker.set_dtype(0, dtype::Int8());
  1057. checker.set_dtype(1, dtype::Int8());
  1058. checker.set_dtype(2, dtype::Int32());
  1059. checker.set_dtype(4, dtype::Int32());
  1060. for (auto&& arg : args) {
  1061. checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}});
  1062. }
  1063. UniformIntRNG rng{-50, 50};
  1064. for (auto&& arg : args) {
  1065. checker.set_dtype(0, dtype::QuantizedS8(2.5f))
  1066. .set_dtype(1, dtype::QuantizedS8(2.5f))
  1067. .set_dtype(2, dtype::QuantizedS32(6.25f))
  1068. .set_dtype(4, {})
  1069. .set_rng(0, &rng)
  1070. .set_rng(1, &rng)
  1071. .set_rng(2, &rng)
  1072. .set_param(arg.param)
  1073. .execs({arg.src, arg.filter, {}, {}, {}});
  1074. }
  1075. }
  1076. #if MEGDNN_AARCH64 || MEGDNN_ARMV7
  1077. #if !__ARM_FEATURE_DOTPROD
//! NCHW44 int8x8x32 im2col tests, single-thread fixture (non-dotprod builds
//! only; MK4-packed matmul backends).
TEST_F(ARM_COMMON, CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44) {
    using namespace conv_bias;
    std::vector<conv_bias::TestArg> args =
            get_nchw44_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, true, true);
#define cb(name) checker_conv_bias_mul_int8x8x32(args, handle(), name);
#if MEGDNN_AARCH64
    cb("IM2COLMATMUL:AARCH64_INT8X8X32_MK4_4X4X16:96");
#else
    cb("IM2COLMATMUL:ARMV7_INT8X8X32_MK4_4X2X16:96");
#endif
#undef cb
}
//! same NCHW44 int8x8x32 coverage as above, run on the multi-thread fixture.
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44_MULTI) {
    using namespace conv_bias;
    std::vector<conv_bias::TestArg> args =
            get_nchw44_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, true, true);
#define cb(name) checker_conv_bias_mul_int8x8x32(args, handle(), name);
#if MEGDNN_AARCH64
    cb("IM2COLMATMUL:AARCH64_INT8X8X32_MK4_4X4X16:96");
#else
    cb("IM2COLMATMUL:ARMV7_INT8X8X32_MK4_4X2X16:96");
#endif
#undef cb
}
//! NCHW44 symmetric-quantized im2col tests, single-thread fixture.
TEST_F(ARM_COMMON, CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM_NCHW44) {
    UniformIntRNG rng{-50, 50};
#define cb(name)                                                             \
    checker_conv_bias(get_nchw44_conv_bias_args({2, 3, 4, 5, 6, 7}, 1),      \
                      handle(), &rng, epsilon, dtype::QuantizedS8(2.5f),     \
                      dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f),  \
                      dtype::QuantizedS8(60.25f), name);
    float epsilon = 0.001;
#if MEGDNN_AARCH64
    cb("IM2COLMATMUL:AARCH64_INT8X8X32_MK4_4X4X16:96");
#else
    cb("IM2COLMATMUL:ARMV7_INT8X8X32_MK4_4X2X16:96");
#endif
#undef cb
}
//! same NCHW44 symmetric-quantized coverage, multi-thread fixture.
TEST_F(ARM_COMMON_MULTI_THREADS,
       CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM_NCHW44_MULTI) {
    UniformIntRNG rng{-50, 50};
#define cb(name)                                                             \
    checker_conv_bias(get_nchw44_conv_bias_args({2, 3, 4, 5, 6, 7}, 1),      \
                      handle(), &rng, epsilon, dtype::QuantizedS8(2.5f),     \
                      dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f),  \
                      dtype::QuantizedS8(60.25f), name);
    float epsilon = 0.001;
#if MEGDNN_AARCH64
    cb("IM2COLMATMUL:AARCH64_INT8X8X32_MK4_4X4X16:96");
#else
    cb("IM2COLMATMUL:ARMV7_INT8X8X32_MK4_4X2X16:96");
#endif
#undef cb
}
  1133. #if MEGDNN_AARCH64
//! NCHW44 fused (3x3 only) symmetric-quantized case, aarch64 only.
TEST_F(ARM_COMMON_MULTI_THREADS,
       CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM_NCHW44_FUSE) {
    UniformIntRNG rng{-50, 50};
#define cb(name)                                                            \
    checker_conv_bias(get_nchw44_conv_bias_args({3}, 1), handle(), &rng,    \
                      epsilon, dtype::QuantizedS8(2.5f),                    \
                      dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), \
                      dtype::QuantizedS8(60.25f), name);
    float epsilon = 0.001;
    cb("IM2COLMATMUL:AARCH64_INT8X8X32_MK4_4X4X16:96");
#undef cb
}
  1146. #endif
  1147. #endif
  1148. #endif
//! plain-layout int8x8x32 im2col tests: stride-1 kernels 2-7 plus the
//! stride-2 1x1 case merged into one argument list.
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32) {
    using namespace conv_bias;
    std::vector<conv_bias::TestArg> args =
            get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, true, true);
    std::vector<conv_bias::TestArg> args1 =
            get_conv_bias_args({1}, 2, false, true, true);
    args.insert(args.begin(), args1.begin(), args1.end());
#define cb(name) checker_conv_bias_mul_int8x8x32(args, handle(), name);
#if MEGDNN_AARCH64
#if __ARM_FEATURE_DOTPROD
    cb("IM2COLMATMUL:AARCH64_INT8X8X32_K8X12X4_DOTPROD");
#else
    cb("IM2COLMATMUL:AARCH64_INT8X8X32_K8X8X8");
    cb("IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16");
#endif
#elif MEGDNN_ARMV7
#if __ARM_FEATURE_DOTPROD
    cb("IM2COLMATMUL:AARCH32_INT8_K6X8X4");
#endif
    cb("IM2COLMATMUL:ARMV7_INT8X8X32_K4X8X8");
#endif
//! ARMV7_INT8X8X32_K4X2X16 is checked regardless of dotprod support
#if MEGDNN_ARMV7
    cb("IM2COLMATMUL:ARMV7_INT8X8X32_K4X2X16");
#endif
#undef cb
}
  1175. /***************************** Conv1x1 Algo Test ***********************/
//! fp32 Conv1x1 (stride-1 pointwise) algorithm; the trailing ":<n>" in the
//! algorithm name selects the pack-block parameter.
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_F32) {
    using namespace conv_bias;
    std::vector<conv_bias::TestArg> args = get_conv_bias_1x1_args(false, false);
#if MEGDNN_AARCH64
    check_conv_bias(args, handle(), "CONV1x1:AARCH64_F32K8X12X1:24");
#elif MEGDNN_ARMV7
    check_conv_bias(args, handle(), "CONV1x1:ARMV7_F32:48");
#endif
}
  1185. #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
//! fp16 Conv1x1 with a normal RNG and loose tolerance (0.03).
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_F16) {
    using namespace conv_bias;
    std::vector<conv_bias::TestArg> args = get_conv_bias_1x1_args(false, false);
    NormalRNG rng(1);
#if MEGDNN_AARCH64
    checker_conv_bias(args, handle(), &rng, 0.03, dtype::Float16{},
                      dtype::Float16{}, dtype::Float16{}, dtype::Float16{},
                      "CONV1x1:AARCH64_F16_K8X24X1:48");
#elif MEGDNN_ARMV7
    checker_conv_bias(args, handle(), &rng, 0.03, dtype::Float16{},
                      dtype::Float16{}, dtype::Float16{}, dtype::Float16{},
                      "CONV1x1:AARCH32_F16_K4X16X1:24");
#endif
}
  1200. #endif
//! symmetric-quantized Conv1x1; ARMv7 relaxes epsilon to 1 for its int8
//! kernel.
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_QUANTIZEDSYM) {
    UniformIntRNG rng{-50, 50};
    float epsilon = 0.001;
#define cb(name)                                                            \
    checker_conv_bias(get_conv_bias_1x1_args(false, false, true, true),     \
                      handle(), &rng, epsilon, dtype::QuantizedS8(2.5f),    \
                      dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), \
                      dtype::QuantizedS8(60.25f), name);
#if MEGDNN_AARCH64
#if __ARM_FEATURE_DOTPROD
    cb("CONV1x1:AARCH64_INT8X8X32_K8X12X4_DOTPROD:24");
#else
    cb("CONV1x1:AARCH64_INT8X8X32_K8X8X8:24");
    cb("CONV1x1:AARCH64_INT8X8X32_K4X4X16:48");
#endif
#elif MEGDNN_ARMV7
    epsilon = 1;
    cb("CONV1x1:ARMV7_INT8X8X32_K4X8X8:48");
#endif
#undef cb
}
  1222. #if MEGDNN_AARCH64 || MEGDNN_ARMV7
//! asymmetric-quantized Conv1x1; bias scale is the product of input scales.
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_QUANTIZEDASYM) {
    NormalRNG rng(128.f);
#define cb(name)                                                          \
    checker_conv_bias(get_conv_bias_1x1_args(false, false, true, true),   \
                      handle(), &rng, epsilon,                            \
                      dtype::Quantized8Asymm(1.2f, (uint8_t)125),         \
                      dtype::Quantized8Asymm(1.3f, (uint8_t)129),         \
                      dtype::QuantizedS32(1.2 * 1.3),                     \
                      dtype::Quantized8Asymm(50.3f, (uint8_t)120), name);
    float epsilon = 0.001;
#if MEGDNN_AARCH64
#if __ARM_FEATURE_DOTPROD
    cb("CONV1x1:AARCH64_QUINT8_K8X8X4_DOTPROD:48");
#else
    cb("CONV1x1:AARCH64_QUINT8_K8X8X8:24");
#endif
#elif MEGDNN_ARMV7
    epsilon = 1;
    cb("CONV1x1:ARMV7_QUINT8_K4X8X8:48");
#endif
#undef cb
}
  1245. #endif
  1246. #if MEGDNN_AARCH64 || MEGDNN_ARMV7
//! quint8 Conv1x1 with raw QS32 output (dst dtype `{}` lets the checker
//! deduce it), i.e. no requantization stage.
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_QUINT8x8x32) {
    UniformIntRNG rng{-50, 50};
    float epsilon = 0.001;
#define cb(name)                                                          \
    checker_conv_bias(get_conv_bias_1x1_args(true, true), handle(), &rng, \
                      epsilon, dtype::Quantized8Asymm(1.2f, (uint8_t)125),\
                      dtype::Quantized8Asymm(1.3f, (uint8_t)129),         \
                      dtype::QuantizedS32(1.2 * 1.3), {}, name);
#if MEGDNN_AARCH64
#if __ARM_FEATURE_DOTPROD
    cb("CONV1x1:AARCH64_QUINT8_K8X8X4_DOTPROD:24");
#else
    cb("CONV1x1:AARCH64_QUINT8_K8X8X8:48");
#endif
#elif MEGDNN_ARMV7
#if __ARM_FEATURE_DOTPROD
    cb("CONV1x1:AARCH32_QUINT8_K4X8X4:48");
#endif
    cb("CONV1x1:ARMV7_QUINT8_K4X8X8:24");
#endif
#undef cb
}
//! int8x8x16 Conv1x1; the generic ARM_COMMON kernel is checked on both archs.
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_1X1_S1_INT8x8x16) {
    UniformIntRNG rng{-50, 50};
    float epsilon = 0.001;
#define cb(name)                                                           \
    checker_conv_bias(get_conv_bias_1x1_args(true, true), handle(), &rng,  \
                      epsilon, dtype::Int8{}, dtype::Int8{},               \
                      dtype::Int16{}, dtype::Int16{}, name);
#if MEGDNN_AARCH64
    cb("CONV1x1:AARCH64_INT8X8X16_K8X8X8:24");
    cb("CONV1x1:AARCH64_INT8X8X16_K4X4X16:24");
#elif MEGDNN_ARMV7
    cb("CONV1x1:ARMV7_INT8X8X16_K4X8X8:24");
    cb("CONV1x1:ARMV7_INT8X8X16_K4X2X16:48");
#endif
    cb("CONV1x1:ARM_COMMON_INT8X8X16:48");
#undef cb
}
  1286. #endif
//! plain-layout int8x8x32 Conv1x1 across per-arch matmul backends.
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_INT8x8x32) {
    using namespace conv_bias;
    std::vector<conv_bias::TestArg> args = get_conv_bias_1x1_args(true, true);
#define cb(name) checker_conv_bias_mul_int8x8x32(args, handle(), name);
#if MEGDNN_AARCH64
#if __ARM_FEATURE_DOTPROD
    cb("CONV1x1:AARCH64_INT8X8X32_K8X12X4_DOTPROD:48");
#else
    cb("CONV1x1:AARCH64_INT8X8X32_K8X8X8:24");
    cb("CONV1x1:AARCH64_INT8X8X32_K4X4X16:24");
#endif
#elif MEGDNN_ARMV7
#if __ARM_FEATURE_DOTPROD
    cb("CONV1x1:AARCH32_INT8_K6X8X4:48");
#endif
    cb("CONV1x1:ARMV7_INT8X8X32_K4X8X8:24");
#endif
//! ARMV7_INT8X8X32_K4X2X16 is checked regardless of dotprod support
#if MEGDNN_ARMV7
    cb("CONV1x1:ARMV7_INT8X8X32_K4X2X16:48");
#endif
#undef cb
}
  1309. #ifndef __ARM_FEATURE_DOTPROD
//! MK4-packed int8x8x32 Conv1x1 (non-dotprod builds): first the int8x8x32
//! checker, then a second pass with symmetric QuantizedS8 output dtypes.
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_INT8x8x32_MK4) {
    using namespace conv_bias;
    std::vector<conv_bias::TestArg> args =
            get_nchw44_conv_bias_args({1}, 1, true, true, true);
#define cb(name) checker_conv_bias_mul_int8x8x32(args, handle(), name);
#if MEGDNN_AARCH64
    cb("CONV1x1:AARCH64_INT8X8X32_MK4_4X4X16:24");
#elif MEGDNN_ARMV7
    cb("CONV1x1:ARMV7_INT8X8X32_MK4_4X2X16:24");
#endif
#undef cb
    UniformIntRNG rng{-50, 50};
    float epsilon = 0.001;
//! second pass: requantized QuantizedS8 output on the same MK4 kernels
#define cb(name)                                                             \
    checker_conv_bias(get_nchw44_conv_bias_args({1}, 1, true, false, false), \
                      handle(), &rng, epsilon, dtype::QuantizedS8(2.5f),     \
                      dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f),  \
                      dtype::QuantizedS8(60.25f), name);
#if MEGDNN_AARCH64
    cb("CONV1x1:AARCH64_INT8X8X32_MK4_4X4X16:24");
#elif MEGDNN_ARMV7
    cb("CONV1x1:ARMV7_INT8X8X32_MK4_4X2X16:24");
#endif
#undef cb
}
  1335. #endif
  1336. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台