You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

conv_bias_multi_thread.cpp 69 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638
  1. /**
  2. * \file dnn/test/arm_common/conv_bias_multi_thread.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #include "test/arm_common/fixture.h"
  13. #include "test/common/benchmarker.h"
  14. #include "test/common/conv_bias.h"
  15. using namespace megdnn;
  16. using namespace test;
  17. using namespace conv_bias;
  18. std::vector<conv_bias::TestArg> get_int8_quint8_conv_bias_args(
  19. std::vector<size_t> kernel, size_t stride, bool no_pad, bool no_bias,
  20. bool no_nonlinemode) {
  21. using namespace conv_bias;
  22. using Param = param::ConvBias;
  23. using NLMode = param::ConvBias::NonlineMode;
  24. std::vector<TestArg> args;
  25. auto pack = [&](size_t n, size_t oc, size_t ic, size_t w, size_t h,
  26. size_t kernel, size_t stride, NLMode nlmode) {
  27. Param param;
  28. param.stride_h = stride;
  29. param.stride_w = stride;
  30. if (!no_pad) {
  31. param.pad_h = kernel / 2;
  32. param.pad_w = kernel / 2;
  33. } else {
  34. param.pad_h = 0;
  35. param.pad_w = 0;
  36. }
  37. param.nonlineMode = nlmode;
  38. args.emplace_back(param, TensorShape{n, ic, h, w},
  39. TensorShape{oc, ic, kernel, kernel}, TensorShape{});
  40. if (!no_bias) {
  41. args.emplace_back(param, TensorShape{n, ic, h, w},
  42. TensorShape{oc, ic, kernel, kernel},
  43. TensorShape{1, oc, 1, 1});
  44. }
  45. };
  46. std::vector<NLMode> nonlinemode = {NLMode::IDENTITY};
  47. if (!no_nonlinemode) {
  48. nonlinemode.emplace_back(NLMode::RELU);
  49. nonlinemode.emplace_back(NLMode::H_SWISH);
  50. }
  51. for (size_t n : {1, 2}) {
  52. for (auto nlmode : nonlinemode) {
  53. for (size_t ic : {1, 3, 7}) {
  54. for (size_t oc : {1, 3, 7}) {
  55. for (size_t size : {4, 6, 8, 14, 16, 18}) {
  56. for (size_t kern : kernel) {
  57. pack(n, oc, ic, size, size, kern, stride, nlmode);
  58. }
  59. }
  60. }
  61. }
  62. }
  63. }
  64. return args;
  65. }
//! Build ConvBias test args in NCHW44 layout (or NCHW->NCHW44 hybrid when
//! is_input_nchw is set). Invalid layout/group/padding combinations are
//! filtered out inside the pack lambda rather than at the loop level.
// NOTE(review): `no_pad` is accepted but never read in this function --
// padding is always kernel/2 unless the lambda's `any_pad` is given;
// confirm whether that is intentional.
std::vector<conv_bias::TestArg> get_nchw44_conv_bias_args(
        std::vector<size_t> kernel_vec, size_t stride, bool no_pad = false,
        bool no_bias = false, bool no_nonlinemode = false,
        bool is_input_nchw = false, bool support_full_bias = false,
        bool support_sigmoid = false) {
    using namespace conv_bias;
    using NLMode = param::ConvBias::NonlineMode;
    std::vector<TestArg> args;
    auto pack = [&](size_t n, size_t oc, size_t ic, size_t h, size_t w,
                    size_t kernel, size_t stride, size_t group, NLMode nlmode,
                    megdnn::BiasMode bias_mode, int any_pad = -1) {
        // NCHW44 packs channels in blocks of 4
        constexpr int pack_c = 4;
        const size_t pad = any_pad >= 0 ? any_pad : kernel / 2;
        auto oc_per_group = oc / group;
        auto ic_per_group = ic / group;
        // group split must divide both channel counts, and per-group OC must
        // stay a multiple of the pack size
        bool ok_group = (oc % group == 0 && ic % group == 0) &&
                        oc_per_group % pack_c == 0 && oc_per_group > 0 &&
                        ic_per_group > 0;
        // hybrid NCHW input path only supports dense, small-IC convolutions
        bool nchw_disable = group > 1 || ic_per_group >= 4;
        // pure NCHW44 path needs IC per group to be pack-aligned
        bool nchw44_disable = ic_per_group % pack_c != 0;
        bool invalid_pad = (w + 2 * pad < kernel) || (h + 2 * pad < kernel);
        if (!(ok_group) || invalid_pad) {
            return;
        }
        if ((is_input_nchw && nchw_disable) ||
            (!is_input_nchw && nchw44_disable)) {
            return;
        }
        size_t kernel_h = kernel;
        size_t kernel_w = kernel;
        param::ConvBias param;
        param.format = param::ConvBias::Format::NCHW44;
        param.stride_h = stride;
        param.stride_w = stride;
        param.pad_h = pad;
        param.pad_w = pad;
        param.nonlineMode = nlmode;
        auto src_tensor_shape = TensorShape{n, ic / pack_c, h, w, pack_c};
        auto weight_tensor_shape = TensorShape{
                oc / pack_c, ic / pack_c, kernel_h, kernel_w, pack_c, pack_c};
        auto bias_tensor_shape = TensorShape{};
        if (bias_mode == megdnn::BiasMode::BROADCAST_CHANNEL_BIAS) {
            bias_tensor_shape = {1, oc / pack_c, 1, 1, pack_c};
        } else if (bias_mode == megdnn::BiasMode::BIAS) {
            // full bias matches the conv output spatial size exactly
            bias_tensor_shape = {n, oc / pack_c,
                                 (h + 2 * pad - kernel) / stride + 1,
                                 (w + 2 * pad - kernel) / stride + 1, pack_c};
        }
        if (group == 1) {
            param.sparse = param::ConvBias::Sparse::DENSE;
        } else if (group > 1 && ic / group == 1 && oc / group == 1) {
            // channel-wise is generated by get_nchw44_channel_wise_args
            // instead; this assert always fires, so the two statements below
            // are unreachable and kept only as documentation of the intended
            // shapes.
            megdnn_assert(0, "not support channel wise");
            param.sparse = param::ConvBias::Sparse::GROUP;
            weight_tensor_shape = TensorShape{group / pack_c, 1, 1,
                                              kernel_h, kernel_w, pack_c};
        } else if (group > 1 && oc_per_group % pack_c == 0 && oc / group > 0 &&
                   ic_per_group % pack_c == 0 && ic / group > 0) {
            param.sparse = param::ConvBias::Sparse::GROUP;
            weight_tensor_shape = TensorShape{group,
                                              oc_per_group / pack_c,
                                              ic_per_group / pack_c,
                                              kernel_h,
                                              kernel_w,
                                              pack_c,
                                              pack_c};
        }
        if (is_input_nchw) {
            // hybrid mode: source stays NCHW, weights use the NCHW->NCHW44
            // transform layout
            src_tensor_shape = TensorShape{n, ic, h, w};
            weight_tensor_shape =
                    TensorShape{oc / pack_c, kernel_h, kernel_w, ic, pack_c};
        }
        args.emplace_back(param, src_tensor_shape, weight_tensor_shape,
                          bias_tensor_shape);
    };
    std::vector<NLMode> nonlinemode = {NLMode::IDENTITY};
    if (!no_nonlinemode) {
        nonlinemode.emplace_back(NLMode::RELU);
        nonlinemode.emplace_back(NLMode::H_SWISH);
    }
    if (support_sigmoid) {
        nonlinemode.emplace_back(NLMode::SIGMOID);
    }
    std::vector<megdnn::BiasMode> bias_mode = {
            megdnn::BiasMode::BROADCAST_CHANNEL_BIAS};
    if (no_bias) {
        bias_mode.emplace_back(megdnn::BiasMode::NO_BIAS);
    }
    if (support_full_bias) {
        bias_mode.emplace_back(megdnn::BiasMode::BIAS);
    }
    for (auto bias : bias_mode)
        for (auto nlmode : nonlinemode)
            for (size_t n : {1, 2})
                for (size_t kernel : kernel_vec)
                    for (size_t oc : {4, 12, 32})
                        for (size_t ic : {1, 3, 4, 12, 32})
                            for (size_t h : {3, 5, 12})
                                for (size_t w : {7, 16, 23}) {
                                    // every divisor-candidate group size is
                                    // tried; pack() rejects invalid ones
                                    for (size_t group = 1;
                                         group <= std::min(oc, ic); ++group) {
                                        pack(n, oc, ic, h, w, kernel, stride,
                                             group, nlmode, bias);
                                    }
                                }
    return args;
}
  172. std::vector<conv_bias::TestArg> get_nchw44_channel_wise_args(
  173. std::vector<size_t> kernel, size_t stride, bool no_bias,
  174. bool no_nonlinemode, bool no_full_bias) {
  175. using namespace conv_bias;
  176. using Param = param::ConvBias;
  177. using NLMode = param::ConvBias::NonlineMode;
  178. std::vector<TestArg> args;
  179. auto pack = [&](size_t n, size_t group, size_t w, size_t h, size_t kernel,
  180. size_t stride, NLMode nlmode, bool pad) {
  181. Param param;
  182. param.stride_h = stride;
  183. param.stride_w = stride;
  184. if (pad) {
  185. param.pad_h = kernel / 2;
  186. param.pad_w = kernel / 2;
  187. } else {
  188. param.pad_h = 0;
  189. param.pad_w = 0;
  190. }
  191. param.nonlineMode = nlmode;
  192. param.format = param::ConvBias::Format::NCHW44;
  193. param.sparse = param::ConvBias::Sparse::GROUP;
  194. args.emplace_back(param, TensorShape{n, group, h, w, 4},
  195. TensorShape{group, 1, 1, kernel, kernel, 4},
  196. TensorShape{});
  197. if (!no_bias) {
  198. args.emplace_back(param, TensorShape{n, group, h, w, 4},
  199. TensorShape{group, 1, 1, kernel, kernel, 4},
  200. TensorShape{1, group, 1, 1, 4});
  201. }
  202. if (!no_full_bias) {
  203. args.emplace_back(
  204. param, TensorShape{n, group, h, w, 4},
  205. TensorShape{group, 1, 1, kernel, kernel, 4},
  206. TensorShape{n, group,
  207. (h + 2 * param.pad_w - kernel) / stride + 1,
  208. (w + 2 * param.pad_w - kernel) / stride + 1,
  209. 4});
  210. }
  211. };
  212. std::vector<NLMode> nonlinemode = {NLMode::IDENTITY};
  213. if (!no_nonlinemode) {
  214. nonlinemode.emplace_back(NLMode::RELU);
  215. nonlinemode.emplace_back(NLMode::H_SWISH);
  216. }
  217. for (size_t n : {1, 2}) {
  218. for (auto nlmode : nonlinemode) {
  219. for (bool pad : {true}) {
  220. for (size_t group : {1, 2, 4, 7, 128}) {
  221. for (size_t size : {4, 6, 7, 9, 15, 40}) {
  222. for (size_t kern : kernel) {
  223. pack(n, group, size, size, kern, stride, nlmode,
  224. pad);
  225. }
  226. }
  227. }
  228. }
  229. for (bool pad : {false}) {
  230. for (size_t group : {1, 2, 7, 128}) {
  231. for (size_t size : {7, 9, 15, 40}) {
  232. for (size_t kern : kernel) {
  233. pack(n, group, size, size, kern, stride, nlmode,
  234. pad);
  235. }
  236. }
  237. }
  238. }
  239. }
  240. }
  241. return args;
  242. }
  243. void checker_conv_bias_qint8x8x8(std::vector<conv_bias::TestArg> args,
  244. Handle* handle, const char* algo_name) {
  245. Checker<ConvBias> checker(handle);
  246. checker.set_before_exec_callback(
  247. conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name));
  248. #if MEGDNN_ARMV7
  249. checker.set_epsilon(1);
  250. #endif
  251. UniformIntRNG rng{-50, 50};
  252. checker.set_dtype(0, dtype::QuantizedS8(0.41113496f))
  253. .set_dtype(1, dtype::QuantizedS8(0.01887994f))
  254. .set_dtype(2, dtype::QuantizedS32(0.41113496f * 0.01887994f))
  255. .set_dtype(4, dtype::QuantizedS8(0.49550694f))
  256. .set_rng(0, &rng)
  257. .set_rng(1, &rng)
  258. .set_rng(2, &rng);
  259. for (auto&& arg : args) {
  260. checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}});
  261. }
  262. }
  263. void checker_conv_bias_qint8x8x32(std::vector<conv_bias::TestArg> args,
  264. Handle* handle, const char* algo_name) {
  265. Checker<ConvBias> checker(handle);
  266. UniformIntRNG rng{-50, 50};
  267. checker.set_dtype(0, dtype::QuantizedS8(2.5f))
  268. .set_dtype(1, dtype::QuantizedS8(2.5f))
  269. .set_dtype(2, dtype::QuantizedS32(6.25f))
  270. .set_dtype(4, {});
  271. checker.set_before_exec_callback(
  272. conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name));
  273. for (auto&& arg : args) {
  274. checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}});
  275. }
  276. }
  277. void checker_conv_bias_quint8x8x8(std::vector<conv_bias::TestArg> args,
  278. Handle* handle, const char* algo_name) {
  279. Checker<ConvBias> checker(handle);
  280. checker.set_before_exec_callback(
  281. conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name));
  282. UniformIntRNG rng(0, 255);
  283. checker.set_dtype(0, dtype::Quantized8Asymm(0.2f, 100))
  284. .set_dtype(1, dtype::Quantized8Asymm(0.2f, 120))
  285. .set_dtype(2, dtype::QuantizedS32(0.04f))
  286. .set_dtype(4, dtype::Quantized8Asymm(1.4f, 110))
  287. .set_rng(0, &rng)
  288. .set_rng(1, &rng)
  289. .set_rng(2, &rng);
  290. for (auto&& arg : args) {
  291. checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}});
  292. }
  293. }
  294. void checker_conv_bias_quint8x8x32(std::vector<conv_bias::TestArg> args,
  295. Handle* handle, const char* algo_name) {
  296. Checker<ConvBias> checker(handle);
  297. checker.set_before_exec_callback(
  298. conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name));
  299. NormalRNG rng(128.f);
  300. checker.set_rng(0, &rng).set_rng(1, &rng);
  301. checker.set_dtype(0, dtype::Quantized8Asymm(1.2f, (uint8_t)127))
  302. .set_dtype(1, dtype::Quantized8Asymm(1.3f, (uint8_t)129))
  303. .set_dtype(2, dtype::QuantizedS32(1.2 * 1.3))
  304. .set_dtype(4, {});
  305. for (auto&& arg : args) {
  306. checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}});
  307. }
  308. }
  309. void checker_conv_bias_int8x8x32_multi(std::vector<conv_bias::TestArg> args,
  310. Handle* handle, const char* algo_name) {
  311. Checker<ConvBias> checker(handle);
  312. checker.set_before_exec_callback(
  313. conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name));
  314. checker.set_dtype(0, dtype::Int8());
  315. checker.set_dtype(1, dtype::Int8());
  316. checker.set_dtype(2, dtype::Int32());
  317. checker.set_dtype(4, dtype::Int32());
  318. for (auto&& arg : args) {
  319. checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}});
  320. }
  321. }
/**********************************F32 direct************************/
// Each TEST_F below feeds a generated arg set to check_conv_bias() and pins
// the algorithm under test via its name string.
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_LARGE_GROUP) {
    check_conv_bias(
            get_conv_bias_args({1, 2, 3, 4, 5, 6, 7}, 1, false, false, false),
            handle(), "F32DIRECT_LARGE_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_SMALL_GROUP) {
    check_conv_bias(
            get_conv_bias_args({1, 2, 3, 4, 5, 6, 7}, 1, false, false, false),
            handle(), "F32DIRECT_SMALL_GROUP");
}
// NCHW44 direct conv, stride 1 split across two tests to bound runtime
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S1_1) {
    check_conv_bias(get_nchw44_conv_bias_args({2, 7}, 1, false, false, false,
                                              false, true, true),
                    handle(), "F32_CONV_NCHW44_DIRECT");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S1_2) {
    check_conv_bias(get_nchw44_conv_bias_args({3, 5}, 1, false, false, false,
                                              false, true, true),
                    handle(), "F32_CONV_NCHW44_DIRECT");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S2) {
    check_conv_bias(get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, false,
                                              false, false, true, true),
                    handle(), "F32_CONV_NCHW44_DIRECT");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR1_LARGE_GROUP) {
    check_conv_bias(get_conv_bias_args({2, 3, 5, 7}, 1, false, false, false),
                    handle(), "F32STRD1_LARGE_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR1_SMALL_GROUP) {
    check_conv_bias(get_conv_bias_args({2, 3, 5, 7}, 1, false, false, false),
                    handle(), "F32STRD1_SMALL_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR2_LARGE_GROUP) {
    check_conv_bias(get_conv_bias_args({2, 3, 5, 7}, 2, false, false, false),
                    handle(), "F32STRD2_LARGE_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR2_SMALL_GROUP) {
    check_conv_bias(get_conv_bias_args({2, 3, 5, 7}, 2, false, false, false),
                    handle(), "F32STRD2_SMALL_GROUP");
}
// hybrid NCHW input -> NCHW44 output path
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_NCHW_NCHW44_F32) {
    check_conv_bias(get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, false,
                                              false, true),
                    handle(), "F32_CONV_NCHW_NCHW44");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CHANNEL_WISE_STRIDE1_FP32_NCHW44) {
    check_conv_bias(
            get_nchw44_channel_wise_args({2, 3, 5}, 1, false, false, false),
            handle(), "F32_CHANNEL_WISE_NCHW44");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CHANNEL_WISE_STRIDE2_FP32_NCHW44) {
    check_conv_bias(
            get_nchw44_channel_wise_args({2, 3, 5}, 2, false, false, false),
            handle(), "F32_CHANNEL_WISE_NCHW44");
}
/**********************************F16 direct************************/
// fp16 tests are compiled only when the toolchain/target provides native
// half-precision vector arithmetic.
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16_LARGE_GROUP) {
    NormalRNG rng(1);
    // 0.03 = relative error tolerance for half precision
    checker_conv_bias_f16(
            get_conv_bias_args({1, 2, 3, 4, 5, 6, 7}, 1, false, false, false),
            handle(), rng, "F16DIRECT_LARGE_GROUP", 0.03);
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16_SMALL_GROUP) {
    NormalRNG rng(1);
    checker_conv_bias_f16(
            get_conv_bias_args({1, 2, 3, 4, 5, 6, 7}, 1, false, false, false),
            handle(), rng, "F16DIRECT_SMALL_GROUP", 0.03);
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16_STR1_LARGE_GROUP) {
    NormalRNG rng(1);
    checker_conv_bias_f16(get_conv_bias_args({2, 3, 5}, 1, false, false, false),
                          handle(), rng, "F16STRD1_LARGE_GROUP", 0.03);
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16_STR1_SMALL_GROUP) {
    NormalRNG rng(1);
    checker_conv_bias_f16(get_conv_bias_args({2, 3, 5}, 1, false, false, false),
                          handle(), rng, "F16STRD1_SMALL_GROUP", 0.03);
}
#endif
/**********************************algo 8816 direct************************/
// int8 x int8 -> int16 direct algorithms; bias and nonlinearity disabled
// (last two generator flags are true) since 8816 supports neither here.
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT16_DIRECT_LARGE_GROUP) {
    checker_conv_bias_int8x8x16(
            get_conv_bias_args({2, 3, 5}, 1, false, true, true), handle(),
            "I8816DIRECT_LARGE_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT16_DIRECT_SMALL_GROUP) {
    checker_conv_bias_int8x8x16(
            get_conv_bias_args({2, 3, 5}, 1, false, true, true), handle(),
            "I8816DIRECT_SMALL_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT16_STRIDE2_LARGE_GROUP) {
    checker_conv_bias_int8x8x16(
            get_conv_bias_args({2, 3, 5}, 2, false, true, true), handle(),
            "I8816STRD2_LARGE_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT16_STRIDE2_SMALL_GROUP) {
    checker_conv_bias_int8x8x16(
            get_conv_bias_args({2, 3, 5}, 2, false, true, true), handle(),
            "I8816STRD2_SMALL_GROUP");
}
/**********************************algo 8-8-32 direct************************/
// int8 x int8 -> int32 direct algorithms (no bias, no nonlinearity).
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT32_STRIDE1_LARGE_GROUP) {
    checker_conv_bias_int8x8x32_multi(
            get_conv_bias_args({2, 3, 5, 7}, 1, false, true, true), handle(),
            "S8STRD1_LARGE_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT32_STRIDE1_SMALL_GROUP) {
    checker_conv_bias_int8x8x32_multi(
            get_conv_bias_args({2, 3, 5, 7}, 1, false, true, true), handle(),
            "S8STRD1_SMALL_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT32_STRIDE2_LARGE_GROUP) {
    checker_conv_bias_int8x8x32_multi(
            get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(),
            "S8STRD2_LARGE_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT32_STRIDE2_SMALL_GROUP) {
    checker_conv_bias_int8x8x32_multi(
            get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(),
            "S8STRD2_SMALL_GROUP");
}
// channel-wise NCHW44 variants of the 8-8-32 path
TEST_F(ARM_COMMON_MULTI_THREADS,
       CONV_BIAS_INT8_INT8_INT32_CHANNEL_WISE_DIRECT1_NCHW44) {
    checker_conv_bias_int8x8x32_multi(
            get_nchw44_channel_wise_args({2, 3, 5}, 1, false, true, true),
            handle(), "S8_CHAN_WISE_STRD1_NCHW44");
}
TEST_F(ARM_COMMON_MULTI_THREADS,
       CONV_BIAS_INT8_INT8_INT32_CHANNEL_WISE_DIRECT2_NCHW44) {
    checker_conv_bias_int8x8x32_multi(
            get_nchw44_channel_wise_args({2, 3, 5}, 2, false, true, true),
            handle(), "S8_CHAN_WISE_STRD2_NCHW44");
}
/********************************qint8 direct******************************/
// quantized int8 direct algorithms with bias and nonlinearities enabled.
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_LARGE_GROUP) {
    checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args(
                                        {2, 3, 5, 7}, 1, false, false, false),
                                handle(), "S8STRD1_LARGE_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_SMALL_GROUP) {
    checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args(
                                        {2, 3, 5, 7}, 1, false, false, false),
                                handle(), "S8STRD1_SMALL_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_LARGE_GROUP) {
    checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args(
                                        {2, 3, 5, 7}, 2, false, false, false),
                                handle(), "S8STRD2_LARGE_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_SMALL_GROUP) {
    checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args(
                                        {2, 3, 5, 7}, 2, false, false, false),
                                handle(), "S8STRD2_SMALL_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_NCHW44) {
    checker_conv_bias_qint8x8x8(
            get_nchw44_conv_bias_args({2, 3, 5, 7}, 1, false, false, false),
            handle(), "S8_NCHW44_DIRECT_STRD1");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_NCHW44) {
    checker_conv_bias_qint8x8x8(
            get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, false, false),
            handle(), "S8_NCHW44_DIRECT_STRD2");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QS8_CHANNEL_WISE_DIRECT1_NCHW44) {
    checker_conv_bias_qint8x8x8(
            get_nchw44_channel_wise_args({2, 3, 5}, 1, false, false, true),
            handle(), "S8_CHAN_WISE_STRD1_NCHW44");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QS8_CHANNEL_WISE_DIRECT2_NCHW44) {
    checker_conv_bias_qint8x8x8(
            get_nchw44_channel_wise_args({2, 3, 5}, 2, false, false, true),
            handle(), "S8_CHAN_WISE_STRD2_NCHW44");
}
// hybrid NCHW input -> NCHW44 output, int8
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_NCHW_NCHW44) {
    checker_conv_bias_qint8x8x8(
            get_nchw44_conv_bias_args({3, 5, 7}, 2, false, false, false, true),
            handle(), "S8_CONV_NCHW_NCHW44");
}
/*****************************quint8 direct****************************/
// asymmetric-quantized uint8 direct algorithms.
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE1_LARGE_GROUP) {
    checker_conv_bias_quint8x8x8(get_int8_quint8_conv_bias_args(
                                         {2, 3, 5, 7}, 1, false, false, false),
                                 handle(), "QU8STRD1_LARGE_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE1_SMALL_GROUP) {
    checker_conv_bias_quint8x8x8(get_int8_quint8_conv_bias_args(
                                         {2, 3, 5, 7}, 1, false, false, false),
                                 handle(), "QU8STRD1_SMALL_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE2_LARGE_GROUP) {
    checker_conv_bias_quint8x8x8(get_int8_quint8_conv_bias_args(
                                         {2, 3, 5, 7}, 2, false, false, false),
                                 handle(), "QU8STRD2_LARGE_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE2_SMALL_GROUP) {
    checker_conv_bias_quint8x8x8(get_int8_quint8_conv_bias_args(
                                         {2, 3, 5, 7}, 2, false, false, false),
                                 handle(), "QU8STRD2_SMALL_GROUP");
}
/****************************dot qint8 direct*************************/
// Everything down to the matching #endif requires the ARM dot-product
// extension (sdot/udot) and is compiled out on targets without it.
#if __ARM_FEATURE_DOTPROD
TEST_F(ARM_COMMON_MULTI_THREADS,
       CONV_BIAS_INT8_STRIDE1_WITHDOTPROD_LARGE_GROUP) {
    checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args(
                                        {2, 3, 5, 7}, 1, false, false, false),
                                handle(), "ARMDOTS8STRD1_LARGE_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS,
       CONV_BIAS_INT8_STRIDE1_WITHDOTPROD_SMALL_GROUP) {
    checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args(
                                        {2, 3, 5, 7}, 1, false, false, false),
                                handle(), "ARMDOTS8STRD1_SMALL_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS,
       CONV_BIAS_INT8_STRIDE2_WITHDOTPROD_LARGE_GROUP) {
    checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args(
                                        {2, 3, 5, 7}, 2, false, false, false),
                                handle(), "ARMDOTS8STRD2_LARGE_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS,
       CONV_BIAS_INT8_STRIDE2_WITHDOTPROD_SMALL_GROUP) {
    checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args(
                                        {2, 3, 5, 7}, 2, false, false, false),
                                handle(), "ARMDOTS8STRD2_SMALL_GROUP");
}
/****************************dot 8-8-32 direct*************************/
// same dot-product kernels, checked with int32 output (no bias/nonlinearity)
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_I8832STRD1_WITHDOT_LARGE_GROUP) {
    checker_conv_bias_qint8x8x32(
            get_conv_bias_args({2, 3, 5, 7}, 1, false, true, true), handle(),
            "ARMDOTS8STRD1_LARGE_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_I8832STRD1_WITHDOT_SMALL_GROUP) {
    checker_conv_bias_qint8x8x32(
            get_conv_bias_args({2, 3, 5, 7}, 1, false, true, true), handle(),
            "ARMDOTS8STRD1_SMALL_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_I8832STRD2_WITHDOT_LARGE_GROUP) {
    checker_conv_bias_qint8x8x32(
            get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(),
            "ARMDOTS8STRD2_LARGE_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_I8832STRD2_WITHDOT_SMALL_GROUP) {
    checker_conv_bias_qint8x8x32(
            get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(),
            "ARMDOTS8STRD2_SMALL_GROUP");
}
/******************************dot quint8*****************************/
TEST_F(ARM_COMMON_MULTI_THREADS,
       CONV_BIAS_QUINT8_STRIDE1_WITHDOTPROD_LARGE_GROUP) {
    checker_conv_bias_quint8x8x8(get_int8_quint8_conv_bias_args(
                                         {2, 3, 5, 7}, 1, false, false, false),
                                 handle(), "ARMDOTU8STRD1_LARGE_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS,
       CONV_BIAS_QUINT8_STRIDE1_WITHDOTPROD_SMALL_GROUP) {
    checker_conv_bias_quint8x8x8(get_int8_quint8_conv_bias_args(
                                         {2, 3, 5, 7}, 1, false, false, false),
                                 handle(), "ARMDOTU8STRD1_SMALL_GROUP");
}
// note: stride-2 quint8 dot tests skip kernel 3 ({2, 5, 7} below)
TEST_F(ARM_COMMON_MULTI_THREADS,
       CONV_BIAS_QUINT8_STRIDE2_WITHDOTPROD_LARGE_GROUP) {
    checker_conv_bias_quint8x8x8(
            get_int8_quint8_conv_bias_args({2, 5, 7}, 2, false, false, false),
            handle(), "ARMDOTU8STRD2_LARGE_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS,
       CONV_BIAS_QUINT8_STRIDE2_WITHDOTPROD_SMALL_GROUP) {
    checker_conv_bias_quint8x8x8(
            get_int8_quint8_conv_bias_args({2, 5, 7}, 2, false, false, false),
            handle(), "ARMDOTU8STRD2_SMALL_GROUP");
}
/******************************dot quint8x8x32***********************/
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_QUINT8_DIRECT_STRIDE1_LARGE_GROUP) {
    checker_conv_bias_quint8x8x32(
            get_conv_bias_args({2, 3, 5, 7}, 1, false, true, true), handle(),
            "ARMDOTU8STRD1_LARGE_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_QUINT8_DIRECT_STRIDE1_SMALL_GROUP) {
    checker_conv_bias_quint8x8x32(
            get_conv_bias_args({2, 3, 5, 7}, 1, false, true, true), handle(),
            "ARMDOTU8STRD1_SMALL_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_QUINT8_DIRECT_STRIDE2_LARGE_GROUP) {
    checker_conv_bias_quint8x8x32(
            get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(),
            "ARMDOTU8STRD2_LARGE_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_QUINT8_DIRECT_STRIDE2_SMALL_GROUP) {
    checker_conv_bias_quint8x8x32(
            get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(),
            "ARMDOTU8STRD2_SMALL_GROUP");
}
#endif
  619. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F23_4) {
  620. using namespace conv_bias;
  621. std::vector<TestArg> args = get_winograd_mk_packed_args();
  622. Checker<ConvBiasForward> checker(handle());
  623. check_winograd("4:2:32", checker, args, param::MatrixMul::Format::MK4);
  624. }
  625. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F23_4_NCHW44) {
  626. using namespace conv_bias;
  627. std::vector<TestArg> args = get_nchw44_conv_bias_args({3}, 1);
  628. Checker<ConvBiasForward> checker(handle());
  629. check_winograd("4:2:32", checker, args, param::MatrixMul::Format::MK4,
  630. param::ConvBias::Format::NCHW44);
  631. }
  632. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F63) {
  633. using namespace conv_bias;
  634. std::vector<TestArg> args = get_winograd_args(3);
  635. Checker<ConvBiasForward> checker(handle());
  636. check_winograd("1:6:32", checker, args);
  637. }
  638. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F63_4) {
  639. using namespace conv_bias;
  640. std::vector<TestArg> args = get_winograd_mk_packed_args();
  641. Checker<ConvBiasForward> checker(handle());
  642. check_winograd("4:6:16", checker, args, param::MatrixMul::Format::MK4);
  643. }
  644. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F63_4_NCHW44) {
  645. using namespace conv_bias;
  646. std::vector<TestArg> args = get_nchw44_conv_bias_args({3}, 1);
  647. Checker<ConvBiasForward> checker(handle());
  648. check_winograd("4:6:16", checker, args, param::MatrixMul::Format::MK4,
  649. param::ConvBias::Format::NCHW44);
  650. }
  651. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F54) {
  652. using namespace conv_bias;
  653. std::vector<TestArg> args = get_winograd_args(4);
  654. Checker<ConvBiasForward> checker(handle());
  655. check_winograd("1:5:32", checker, args);
  656. }
  657. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F45) {
  658. using namespace conv_bias;
  659. std::vector<TestArg> args = get_winograd_args(5);
  660. Checker<ConvBiasForward> checker(handle());
  661. check_winograd("1:4:32", checker, args);
  662. }
//! End-to-end winograd check: the "extra" reference impl explicitly runs
//! WinogradFilterPreprocess followed by ConvBias in NCHW_WINOGRAD format, and
//! the Checker compares it against the normal ConvBias execution for every
//! arg and output_block_size m.
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD) {
    using namespace conv_bias;
    std::vector<TestArg> args = get_winograd_args(3);
    Checker<ConvBiasForward> checker(handle());
    auto extra_impl = [](const TensorNDArray& tensors, uint32_t m,
                         param::ConvBias param, Handle* handle) {
        megdnn_assert(param.format == param::ConvBias::Format::NCHW);
        // step 1: transform the filter (tensors[1]) into the winograd domain
        auto winograd_preprocess_opr =
                handle->create_operator<WinogradFilterPreprocess>();
        winograd_preprocess_opr->param().output_block_size = m;
        TensorLayout filter_transform_layout;
        winograd_preprocess_opr->deduce_layout(tensors[1].layout,
                                               filter_transform_layout);
        size_t winograd_preprocess_workspace_in_bytes =
                winograd_preprocess_opr->get_workspace_in_bytes(
                        tensors[1].layout, filter_transform_layout);
        // step 2: conv_bias consuming the pre-transformed filter
        auto conv_bias_opr = handle->create_operator<ConvBias>();
        conv_bias_opr->param() = param;
        conv_bias_opr->param().format = param::ConvBias::Format::NCHW_WINOGRAD;
        conv_bias_opr->param().output_block_size = m;
        size_t conv_bias_workspace_in_bytes =
                conv_bias_opr->get_workspace_in_bytes(
                        tensors[0].layout, filter_transform_layout,
                        tensors[2].layout, tensors[3].layout, tensors[4].layout,
                        nullptr);
        // bundle slots: 0 = transformed filter storage,
        //               1 = conv_bias workspace, 2 = preprocess workspace
        WorkspaceBundle wb(nullptr, {filter_transform_layout.span().dist_byte(),
                                     conv_bias_workspace_in_bytes,
                                     winograd_preprocess_workspace_in_bytes});
        wb.set(malloc(wb.total_size_in_bytes()));
        TensorND filter_transform_tensor(wb.get(0),
                                         std::move(filter_transform_layout));
        winograd_preprocess_opr->exec(tensors[1], filter_transform_tensor,
                                      wb.get_workspace(2));
        conv_bias_opr->exec(tensors[0], filter_transform_tensor, tensors[2],
                            tensors[3], tensors[4], nullptr,
                            wb.get_workspace(1));
        free(wb.ptr());
    };
    auto run = [&checker, &extra_impl](
                       Handle* handle, const std::vector<TestArg>& args,
                       const std::vector<size_t>& out_size, DType A_dtype,
                       DType B_dtype, DType C_dtype, DType D_dtype,
                       const float eps) {
        for (auto&& arg : args) {
            for (uint32_t m : out_size) {
                // bind the current m/param into the reference impl
                checker.set_extra_opr_impl(std::bind(extra_impl,
                                                     std::placeholders::_1, m,
                                                     arg.param, handle));
                checker.set_dtype(0, A_dtype)
                        .set_dtype(1, B_dtype)
                        .set_dtype(2, C_dtype)
                        .set_dtype(4, D_dtype)
                        .set_epsilon(eps)
                        .set_param(arg.param)
                        .execs({arg.src, arg.filter, arg.bias, {}, {}});
            }
        }
    };
    run(handle(), args, {6}, dtype::Float32(), dtype::Float32(),
        dtype::Float32(), dtype::Float32(), 1e-3f);
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    // fp16 pass with a periodical rng and a looser epsilon
    Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00);
    checker.set_rng(0, rng).set_rng(1, rng).set_rng(2, rng);
    run(handle(), args, {6}, dtype::Float16(), dtype::Float16(),
        dtype::Float16(), dtype::Float16(), 0.35f);
#endif
}
//! Same structure as CONV_BIAS_WINOGRAD above, but for the NCHW44 layout:
//! the reference impl runs WinogradFilterPreprocess (MK4 format) and then
//! ConvBias in NCHW44_WINOGRAD format.
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_PREPROCESS_NCHW44) {
    using namespace conv_bias;
    std::vector<TestArg> nchw44_args = get_nchw44_conv_bias_args({3}, 1);
    Checker<ConvBiasForward> checker(handle());
    auto extra_impl = [](const TensorNDArray& tensors, uint32_t m,
                         param::ConvBias param, Handle* handle) {
        megdnn_assert(param.format == param::ConvBias::Format::NCHW44);
        // step 1: transform the filter into winograd domain (MK4 packing)
        auto winograd_preprocess_opr =
                handle->create_operator<WinogradFilterPreprocess>();
        winograd_preprocess_opr->param().output_block_size = m;
        winograd_preprocess_opr->param().format = param::MatrixMul::Format::MK4;
        TensorLayout filter_transform_layout;
        winograd_preprocess_opr->deduce_layout(tensors[1].layout,
                                               filter_transform_layout);
        size_t winograd_preprocess_workspace_in_bytes =
                winograd_preprocess_opr->get_workspace_in_bytes(
                        tensors[1].layout, filter_transform_layout);
        // step 2: conv_bias consuming the pre-transformed filter
        auto conv_bias_opr = handle->create_operator<ConvBias>();
        conv_bias_opr->param() = param;
        conv_bias_opr->param().format = param::ConvBias::Format::NCHW44_WINOGRAD;
        conv_bias_opr->param().output_block_size = m;
        size_t conv_bias_workspace_in_bytes =
                conv_bias_opr->get_workspace_in_bytes(
                        tensors[0].layout, filter_transform_layout,
                        tensors[2].layout, tensors[3].layout,
                        tensors[4].layout, nullptr);
        // bundle slots: 0 = transformed filter storage,
        //               1 = conv_bias workspace, 2 = preprocess workspace
        WorkspaceBundle wb(nullptr, {filter_transform_layout.span().dist_byte(),
                                     conv_bias_workspace_in_bytes,
                                     winograd_preprocess_workspace_in_bytes});
        wb.set(malloc(wb.total_size_in_bytes()));
        TensorND filter_transform_tensor(wb.get(0),
                                         std::move(filter_transform_layout));
        winograd_preprocess_opr->exec(tensors[1], filter_transform_tensor,
                                      wb.get_workspace(2));
        conv_bias_opr->exec(tensors[0], filter_transform_tensor, tensors[2],
                            tensors[3], tensors[4], nullptr,
                            wb.get_workspace(1));
        free(wb.ptr());
    };
    auto run = [&checker, &extra_impl](
                       Handle* handle, const std::vector<TestArg>& args,
                       const std::vector<size_t>& out_size, DType A_dtype,
                       DType B_dtype, DType C_dtype, DType D_dtype,
                       const float eps) {
        for (auto&& arg : args) {
            for (uint32_t m : out_size) {
                checker.set_extra_opr_impl(std::bind(extra_impl,
                                                     std::placeholders::_1, m,
                                                     arg.param, handle));
                checker.set_dtype(0, A_dtype)
                        .set_dtype(1, B_dtype)
                        .set_dtype(2, C_dtype)
                        .set_dtype(4, D_dtype)
                        .set_epsilon(eps)
                        .set_param(arg.param)
                        .execs({arg.src, arg.filter, arg.bias, {}, {}});
            }
        }
    };
    run(handle(), nchw44_args, {2, 6}, dtype::Float32(), dtype::Float32(),
        dtype::Float32(), dtype::Float32(), 1e-3f);
}
//! winograd MK4 fp32, first half of the arg set (split across two tests to
//! keep each test's runtime down); checked against winograd_algo_extra_impl.
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_MK_PACKED_F32_1) {
    using namespace conv_bias;
    Checker<ConvBiasForward> checker(handle());
    auto run = [&checker](Handle* handle, const std::vector<TestArg>& args,
                          const std::vector<size_t>& out_size, DType A_dtype,
                          DType B_dtype, DType C_dtype, DType D_dtype,
                          param::MatrixMul::Format format, float eps) {
        for (auto&& arg : args) {
            for (uint32_t m : out_size) {
                // reference impl bound to the current output block size m
                checker.set_extra_opr_impl(std::bind(
                        winograd_algo_extra_impl, std::placeholders::_1, m,
                        arg.param, handle, format));
                checker.set_dtype(0, A_dtype)
                        .set_dtype(1, B_dtype)
                        .set_dtype(2, C_dtype)
                        .set_dtype(4, D_dtype)
                        .set_epsilon(eps)
                        .set_param(arg.param)
                        .execs({arg.src, arg.filter, arg.bias, {}, {}});
            }
        }
    };
    std::vector<TestArg> args = get_winograd_mk_packed_args(8);
    std::vector<TestArg> args_first_half(args.begin(),
                                         args.begin() + args.size() / 2);
    run(handle(), args_first_half, {2, 6}, dtype::Float32{}, dtype::Float32{},
        dtype::Float32{}, dtype::Float32{}, param::MatrixMul::Format::MK4,
        1e-3f);
}
//! winograd MK4 fp32, second half of the arg set (companion of F32_1).
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_MK_PACKED_F32_2) {
    using namespace conv_bias;
    Checker<ConvBiasForward> checker(handle());
    auto run = [&checker](Handle* handle, const std::vector<TestArg>& args,
                          const std::vector<size_t>& out_size, DType A_dtype,
                          DType B_dtype, DType C_dtype, DType D_dtype,
                          param::MatrixMul::Format format, float eps) {
        for (auto&& arg : args) {
            for (uint32_t m : out_size) {
                // reference impl bound to the current output block size m
                checker.set_extra_opr_impl(std::bind(
                        winograd_algo_extra_impl, std::placeholders::_1, m,
                        arg.param, handle, format));
                checker.set_dtype(0, A_dtype)
                        .set_dtype(1, B_dtype)
                        .set_dtype(2, C_dtype)
                        .set_dtype(4, D_dtype)
                        .set_epsilon(eps)
                        .set_param(arg.param)
                        .execs({arg.src, arg.filter, arg.bias, {}, {}});
            }
        }
    };
    std::vector<TestArg> args = get_winograd_mk_packed_args(8);
    std::vector<TestArg> args_second_half(args.begin() + args.size() / 2,
                                          args.end());
    run(handle(), args_second_half, {2, 6}, dtype::Float32{}, dtype::Float32{},
        dtype::Float32{}, dtype::Float32{}, param::MatrixMul::Format::MK4,
        1e-3f);
}
  850. #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  851. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_MK_PACKED_F16) {
  852. using namespace conv_bias;
  853. Checker<ConvBiasForward> checker(handle());
  854. auto run = [&checker](Handle* handle, const std::vector<TestArg>& args,
  855. const std::vector<size_t>& out_size, DType A_dtype,
  856. DType B_dtype, DType C_dtype, DType D_dtype,
  857. param::MatrixMul::Format format, float eps) {
  858. for (auto&& arg : args) {
  859. for (uint32_t m : out_size) {
  860. checker.set_extra_opr_impl(std::bind(
  861. winograd_algo_extra_impl, std::placeholders::_1, m,
  862. arg.param, handle, format));
  863. checker.set_dtype(0, A_dtype)
  864. .set_dtype(1, B_dtype)
  865. .set_dtype(2, C_dtype)
  866. .set_dtype(4, D_dtype)
  867. .set_epsilon(eps)
  868. .set_param(arg.param)
  869. .execs({arg.src, arg.filter, arg.bias, {}, {}});
  870. }
  871. }
  872. };
  873. std::vector<TestArg> args = get_winograd_mk_packed_args(8);
  874. Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00);
  875. checker.set_rng(0, rng).set_rng(1, rng).set_rng(2, rng);
  876. run(handle(), args, {2}, dtype::Float16{}, dtype::Float16{},
  877. dtype::Float16{}, dtype::Float16{}, param::MatrixMul::Format::MK8,
  878. 0.25);
  879. }
  880. #endif
  881. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_MK_PACKED_INT8) {
  882. using namespace conv_bias;
  883. Checker<ConvBiasForward> checker(handle());
  884. auto run = [&checker](Handle* handle, const std::vector<TestArg>& args,
  885. const std::vector<size_t>& out_size, DType A_dtype,
  886. DType B_dtype, DType C_dtype, DType D_dtype,
  887. param::MatrixMul::Format format, float eps) {
  888. for (auto&& arg : args) {
  889. for (uint32_t m : out_size) {
  890. checker.set_extra_opr_impl(std::bind(
  891. winograd_algo_extra_impl, std::placeholders::_1, m,
  892. arg.param, handle, format));
  893. checker.set_dtype(0, A_dtype)
  894. .set_dtype(1, B_dtype)
  895. .set_dtype(2, C_dtype)
  896. .set_dtype(4, D_dtype)
  897. .set_epsilon(eps)
  898. .set_param(arg.param)
  899. .execs({arg.src, arg.filter, arg.bias, {}, {}});
  900. }
  901. }
  902. };
  903. #if MEGDNN_AARCH64
  904. const char* matmul_name = "AARCH64_INT16X16X32_MK8_8X8";
  905. #else
  906. const char* matmul_name = "ARMV7_INT16X16X32_MK8_4X8";
  907. #endif
  908. checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>(
  909. ssprintf("WINOGRAD:%s:8:2:32", matmul_name).c_str()));
  910. std::vector<TestArg> args = get_winograd_mk_packed_args(8);
  911. std::vector<TestArg> quantized_args =
  912. get_quantized_winograd_mk_packed_args(8);
  913. UniformIntRNG int_rng{-50, 50};
  914. checker.set_rng(0, &int_rng).set_rng(1, &int_rng).set_rng(2, &int_rng);
  915. run(handle(), quantized_args, {2}, dtype::QuantizedS8(2.5f),
  916. dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f),
  917. dtype::QuantizedS8(60.25f), param::MatrixMul::Format::MK8, 1e-3);
  918. }
  919. #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  920. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_F23) {
  921. using namespace conv_bias;
  922. std::vector<TestArg> args = get_winograd_mk_packed_args();
  923. Checker<ConvBiasForward> checker(handle());
  924. check_winograd_fp16("1:2:32", checker, args, NULL, 0.08);
  925. }
  926. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_F45_1) {
  927. using namespace conv_bias;
  928. std::vector<TestArg> args = get_winograd_args(5);
  929. std::vector<TestArg> args_head_half(args.begin(),
  930. args.begin() + args.size() / 2);
  931. Checker<ConvBiasForward> checker(handle());
  932. //! fp16 range -1.0 ~ 1.0
  933. Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00);
  934. check_winograd_fp16("1:4:32", checker, args_head_half, rng, 0.25);
  935. }
  936. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_F45_2) {
  937. using namespace conv_bias;
  938. std::vector<TestArg> args = get_winograd_args(5);
  939. std::vector<TestArg> args_back_half(args.begin() + args.size() / 2,
  940. args.end());
  941. Checker<ConvBiasForward> checker(handle());
  942. //! fp16 range -1.0 ~ 1.0
  943. Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00);
  944. check_winograd_fp16("1:4:32", checker, args_back_half, rng, 0.25);
  945. }
  946. //! FIXME: This test may be failed if run `ARM_COMMON.CONV_BIAS_WINOGRAD*`, but
  947. //! it will pass when run single testcase
  948. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_F63) {
  949. using namespace conv_bias;
  950. std::vector<TestArg> args = get_winograd_args(3);
  951. Checker<ConvBiasForward> checker(handle());
  952. //! fp16 range -1.0 ~ 1.0
  953. Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00);
  954. check_winograd_fp16("1:6:32", checker, args, rng, 0.3);
  955. }
  956. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_8x8_1) {
  957. using namespace conv_bias;
  958. std::vector<TestArg> args = get_winograd_mk_packed_args(8);
  959. std::vector<TestArg> args_head_half(args.begin(),
  960. args.begin() + args.size() / 2);
  961. Checker<ConvBiasForward> checker(handle());
  962. Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00);
  963. check_winograd_fp16("8:2:32", checker, args_head_half, rng, 0.25,
  964. param::MatrixMul::Format::MK8);
  965. }
  966. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_8x8_2) {
  967. using namespace conv_bias;
  968. std::vector<TestArg> args = get_winograd_mk_packed_args(8);
  969. std::vector<TestArg> args_back_half(args.begin() + args.size() / 2,
  970. args.end());
  971. Checker<ConvBiasForward> checker(handle());
  972. Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00);
  973. check_winograd_fp16("8:2:32", checker, args_back_half, rng, 0.25,
  974. param::MatrixMul::Format::MK8);
  975. }
  976. #endif
  977. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_INT8_8X8) {
  978. using namespace conv_bias;
  979. std::vector<TestArg> args = get_quantized_winograd_mk_packed_args(8);
  980. Checker<ConvBiasForward> checker(handle());
  981. UniformIntRNG rng{-50, 50};
  982. checker.set_dtype(0, dtype::QuantizedS8(2.5f))
  983. .set_dtype(1, dtype::QuantizedS8(2.5f))
  984. .set_dtype(2, dtype::QuantizedS32(6.25f))
  985. .set_dtype(4, dtype::QuantizedS8(60.25f))
  986. .set_rng(0, &rng)
  987. .set_rng(1, &rng)
  988. .set_rng(2, &rng);
  989. check_winograd("8:2:32", checker, args, param::MatrixMul::Format::MK8);
  990. }
  991. void checker_conv_bias(std::vector<conv_bias::TestArg> args, Handle* handle,
  992. RNG* rng, float epsilon, DType type0, DType type1,
  993. DType type2, DType type3, const char* algo_name) {
  994. using namespace conv_bias;
  995. Checker<ConvBias> checker(handle);
  996. checker.set_before_exec_callback(
  997. conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name));
  998. checker.set_dtype(0, type0);
  999. checker.set_dtype(1, type1);
  1000. checker.set_dtype(2, type2);
  1001. checker.set_dtype(4, type3);
  1002. checker.set_epsilon(epsilon);
  1003. if (NULL != rng) {
  1004. checker.set_rng(0, rng).set_rng(1, rng).set_rng(2, rng).set_rng(3, rng);
  1005. }
  1006. for (auto&& arg : args) {
  1007. checker.set_param(arg.param).execs(
  1008. {arg.src, arg.filter, arg.bias, {}, {}});
  1009. }
  1010. }
  1011. // clang-format off
//! im2col fp32 stride-2: run check_conv_bias once per candidate matmul algo.
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COL_FP32_STRIDE2) {
// cb(name): kernels 1..7 at stride 2 with the named im2col matmul algo
#define cb(name)                                                               \
    check_conv_bias(                                                           \
            get_conv_bias_args({1, 2, 3, 4, 5, 6, 7}, 2, false, false, false), \
            handle(), name);
#if MEGDNN_AARCH64
    cb("IM2COLMATMUL:AARCH64_F32K8X12X1")
    cb("IM2COLMATMUL:AARCH64_F32K4X16X1")
    cb("IM2COLMATMUL:FB_F32_K8X12X1")
#elif MEGDNN_ARMV7
    cb("IM2COLMATMUL:ARMV7_F32")
#endif
#undef cb
}
//! im2col fp32 stride-1: run check_conv_bias once per candidate matmul algo.
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COL_FP32_STRIDE1) {
// cb(name): kernels 2..7 at stride 1 with the named im2col matmul algo
#define cb(name)                                                            \
    check_conv_bias(                                                        \
            get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, false, false), \
            handle(), name);
#if MEGDNN_AARCH64
    cb("IM2COLMATMUL:AARCH64_F32K8X12X1")
    cb("IM2COLMATMUL:AARCH64_F32K4X16X1")
    cb("IM2COLMATMUL:FB_F32_K8X12X1")
#elif MEGDNN_ARMV7
    cb("IM2COLMATMUL:ARMV7_F32")
    cb("IM2COLMATMUL:FB_F32_K8X12X1")
#endif
#undef cb
}
//! im2col with symmetric-quantized int8 (qs8 in, qs32 bias, qs8 out):
//! stride-1 kernels 2..7 plus 1x1 at stride 2.
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM) {
    UniformIntRNG rng{-50, 50};
// cb(name): same quantized dtype setup over both arg sets; `epsilon` is
// read at expansion time, so the per-arch assignments below take effect
#define cb(name)                                                              \
    checker_conv_bias(get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, false, \
                                         false, true, true),                  \
                      handle(), &rng, epsilon, dtype::QuantizedS8(2.5f),      \
                      dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f),   \
                      dtype::QuantizedS8(60.25f), name);                      \
    checker_conv_bias(                                                        \
            get_conv_bias_args({1}, 2, false, false, false, true, true),      \
            handle(), &rng, epsilon, dtype::QuantizedS8(2.5f),                \
            dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f),             \
            dtype::QuantizedS8(60.25f), name);
    float epsilon = 0.001;
#if MEGDNN_AARCH64
#if __ARM_FEATURE_DOTPROD
    cb("IM2COLMATMUL:AARCH64_INT8X8X32_K8X12X4_DOTPROD");
#else
    cb("IM2COLMATMUL:AARCH64_INT8X8X32_K8X8X8");
    cb("IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16");
#endif
#elif MEGDNN_ARMV7
    // looser tolerance on armv7
    epsilon = 1;
    cb("IM2COLMATMUL:ARMV7_INT8X8X32_K4X8X8");
#endif
#undef cb
}
  1068. // clang-format on
  1069. #if MEGDNN_AARCH64 || MEGDNN_ARMV7
//! im2col with asymmetric-quantized uint8 (quint8 in/out, qs32 bias):
//! stride-1 kernels 2..7 plus 1x1 at stride 2.
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUANTIZEDASYM) {
    NormalRNG rng(128.f);
// cb(name): asymmetric dtype setup (distinct zero points per tensor) over
// both arg sets; `epsilon` is read at expansion time
#define cb(name)                                                             \
    checker_conv_bias(get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, false, \
                                         false, true, true),                 \
                      handle(), &rng, epsilon,                               \
                      dtype::Quantized8Asymm(1.2f, (uint8_t)125),            \
                      dtype::Quantized8Asymm(1.3f, (uint8_t)129),            \
                      dtype::QuantizedS32(1.2 * 1.3),                        \
                      dtype::Quantized8Asymm(50.3f, (uint8_t)120), name);    \
    checker_conv_bias(                                                       \
            get_conv_bias_args({1}, 2, false, false, false, true, true),     \
            handle(), &rng, epsilon,                                         \
            dtype::Quantized8Asymm(1.2f, (uint8_t)125),                      \
            dtype::Quantized8Asymm(1.3f, (uint8_t)129),                      \
            dtype::QuantizedS32(1.2 * 1.3),                                  \
            dtype::Quantized8Asymm(50.3f, (uint8_t)120), name);
    float epsilon = 0.001;
#if MEGDNN_AARCH64
#if __ARM_FEATURE_DOTPROD
    cb("IM2COLMATMUL:AARCH64_QUINT8_K8X8X4_DOTPROD");
#else
    cb("IM2COLMATMUL:AARCH64_QUINT8_K8X8X8");
#endif
#elif MEGDNN_ARMV7
    // looser tolerance on armv7
    epsilon = 1;
    cb("IM2COLMATMUL:ARMV7_QUINT8_K4X8X8");
#endif
#undef cb
}
  1100. #endif
  1101. #if MEGDNN_AARCH64 || MEGDNN_ARMV7
//! im2col quint8 with raw int32 output (empty dst dtype): stride-1 kernels
//! 2..7 plus 1x1 at stride 2.
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUINT8x8x32) {
    UniformIntRNG rng{-50, 50};
    float epsilon = 0.001;
// cb(name): quint8 in, qs32 accumulate, `{}` dst dtype (deduced)
#define cb(name)                                                             \
    checker_conv_bias(                                                       \
            get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, true, true),    \
            handle(), &rng, epsilon,                                         \
            dtype::Quantized8Asymm(1.2f, (uint8_t)125),                      \
            dtype::Quantized8Asymm(1.3f, (uint8_t)129),                      \
            dtype::QuantizedS32(1.2 * 1.3), {}, name);                       \
    checker_conv_bias(get_conv_bias_args({1}, 2, false, true, true), handle(), \
                      &rng, epsilon,                                         \
                      dtype::Quantized8Asymm(1.2f, (uint8_t)125),            \
                      dtype::Quantized8Asymm(1.3f, (uint8_t)129),            \
                      dtype::QuantizedS32(1.2 * 1.3), {}, name);
#if MEGDNN_AARCH64
#if __ARM_FEATURE_DOTPROD
    cb("IM2COLMATMUL:AARCH64_QUINT8_K8X8X4_DOTPROD");
#else
    cb("IM2COLMATMUL:AARCH64_QUINT8_K8X8X8");
#endif
#elif MEGDNN_ARMV7
#if __ARM_FEATURE_DOTPROD
    cb("IM2COLMATMUL:AARCH32_QUINT8_K4X8X4");
#endif
    cb("IM2COLMATMUL:ARMV7_QUINT8_K4X8X8");
#endif
#undef cb
}
//! im2col int8 input with int16 accumulate/output: stride-1 kernels 2..7
//! plus 1x1 at stride 2.
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COLMATMUL_INT8x8x16) {
    UniformIntRNG rng{-50, 50};
    float epsilon = 0.001;
// cb(name): int8 in, int16 bias and output, over both arg sets
#define cb(name)                                                             \
    checker_conv_bias(                                                       \
            get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, true, true),    \
            handle(), &rng, epsilon, dtype::Int8{}, dtype::Int8{},           \
            dtype::Int16{}, dtype::Int16{}, name);                           \
    checker_conv_bias(get_conv_bias_args({1}, 2, false, true, true), handle(), \
                      &rng, epsilon, dtype::Int8{}, dtype::Int8{},           \
                      dtype::Int16{}, dtype::Int16{}, name);
#if MEGDNN_AARCH64
    cb("IM2COLMATMUL:AARCH64_INT8X8X16_K8X8X8");
    cb("IM2COLMATMUL:AARCH64_INT8X8X16_K4X4X16");
    cb("IM2COLMATMUL:ARM_COMMON_INT8X8X16");
#elif MEGDNN_ARMV7
    cb("IM2COLMATMUL:ARM_COMMON_INT8X8X16");
    cb("IM2COLMATMUL:ARMV7_INT8X8X16_K4X8X8");
    cb("IM2COLMATMUL:ARMV7_INT8X8X16_K4X2X16");
#endif
#undef cb
}
  1153. #endif
  1154. #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  1155. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_FP16) {
  1156. using namespace conv_bias;
  1157. param::ConvBias cur_param;
  1158. std::vector<conv_bias::TestArg> args =
  1159. get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, false, false);
  1160. std::vector<conv_bias::TestArg> args1 =
  1161. get_conv_bias_args({1}, 2, false, false, false);
  1162. args.insert(args.begin(), args1.begin(), args1.end());
  1163. NormalRNG rng(1);
  1164. #define cb(name) \
  1165. checker_conv_bias(args, handle(), &rng, 0.03, dtype::Float16{}, \
  1166. dtype::Float16{}, dtype::Float16{}, dtype::Float16{}, \
  1167. name);
  1168. #if MEGDNN_AARCH64
  1169. cb("IM2COLMATMUL:AARCH64_F16_K8X24X1");
  1170. #elif MEGDNN_ARMV7
  1171. cb("IM2COLMATMUL:AARCH32_F16_K4X16X1");
  1172. #endif
  1173. #undef cb
  1174. }
  1175. #endif
  1176. void checker_conv_bias_mul_int8x8x32(std::vector<conv_bias::TestArg> args,
  1177. Handle* handle, const char* algo_name) {
  1178. using namespace conv_bias;
  1179. Checker<ConvBias> checker(handle);
  1180. checker.set_before_exec_callback(
  1181. conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name));
  1182. checker.set_dtype(0, dtype::Int8());
  1183. checker.set_dtype(1, dtype::Int8());
  1184. checker.set_dtype(2, dtype::Int32());
  1185. checker.set_dtype(4, dtype::Int32());
  1186. for (auto&& arg : args) {
  1187. checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}});
  1188. }
  1189. UniformIntRNG rng{-50, 50};
  1190. for (auto&& arg : args) {
  1191. checker.set_dtype(0, dtype::QuantizedS8(2.5f))
  1192. .set_dtype(1, dtype::QuantizedS8(2.5f))
  1193. .set_dtype(2, dtype::QuantizedS32(6.25f))
  1194. .set_dtype(4, {})
  1195. .set_rng(0, &rng)
  1196. .set_rng(1, &rng)
  1197. .set_rng(2, &rng)
  1198. .set_param(arg.param)
  1199. .execs({arg.src, arg.filter, {}, {}, {}});
  1200. }
  1201. }
  1202. #if MEGDNN_AARCH64 || MEGDNN_ARMV7
  1203. #if !__ARM_FEATURE_DOTPROD
//! im2col int8x8x32 on NCHW44 layout, stride 2, kernels {2,5,7}.
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44_S2) {
    using namespace conv_bias;
    std::vector<conv_bias::TestArg> args =
            get_nchw44_conv_bias_args({2, 5, 7}, 2, false, true, true);
// cb(name): run the dual int8/quantized check over args
#define cb(name) checker_conv_bias_mul_int8x8x32(args, handle(), name);
#if MEGDNN_AARCH64
    cb("IM2COLMATMUL:AARCH64_INT8X8X32_MK4_4X4X16:96");
#else
    cb("IM2COLMATMUL:ARMV7_INT8X8X32_MK4_4X2X16:96");
#endif
#undef cb
}
//! im2col int8x8x32 on NCHW44 layout, stride 1, kernels {3,4,6}.
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44_S1) {
    using namespace conv_bias;
    std::vector<conv_bias::TestArg> args =
            get_nchw44_conv_bias_args({3, 4, 6}, 1, false, true, true);
// cb(name): run the dual int8/quantized check over args
#define cb(name) checker_conv_bias_mul_int8x8x32(args, handle(), name);
#if MEGDNN_AARCH64
    cb("IM2COLMATMUL:AARCH64_INT8X8X32_MK4_4X4X16:96");
#else
    cb("IM2COLMATMUL:ARMV7_INT8X8X32_MK4_4X2X16:96");
#endif
#undef cb
}
//! im2col symmetric-quantized int8 on NCHW44, stride 2, kernels {3,4,6}.
TEST_F(ARM_COMMON_MULTI_THREADS,
       CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM_NCHW44_S2) {
    UniformIntRNG rng{-50, 50};
// cb(name): qs8 in, qs32 bias, qs8 out
#define cb(name)                                                              \
    checker_conv_bias(get_nchw44_conv_bias_args({3, 4, 6}, 2), handle(), &rng, \
                      epsilon, dtype::QuantizedS8(2.5f),                      \
                      dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f),   \
                      dtype::QuantizedS8(60.25f), name);
    float epsilon = 0.001;
#if MEGDNN_AARCH64
    cb("IM2COLMATMUL:AARCH64_INT8X8X32_MK4_4X4X16:96");
#else
    cb("IM2COLMATMUL:ARMV7_INT8X8X32_MK4_4X2X16:96");
#endif
#undef cb
}
//! im2col symmetric-quantized int8 on NCHW44, stride 1, kernels {2,5,7}.
TEST_F(ARM_COMMON_MULTI_THREADS,
       CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM_NCHW44_S1) {
    UniformIntRNG rng{-50, 50};
// cb(name): qs8 in, qs32 bias, qs8 out
#define cb(name)                                                              \
    checker_conv_bias(get_nchw44_conv_bias_args({2, 5, 7}, 1), handle(), &rng, \
                      epsilon, dtype::QuantizedS8(2.5f),                      \
                      dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f),   \
                      dtype::QuantizedS8(60.25f), name);
    float epsilon = 0.001;
#if MEGDNN_AARCH64
    cb("IM2COLMATMUL:AARCH64_INT8X8X32_MK4_4X4X16:96");
#else
    cb("IM2COLMATMUL:ARMV7_INT8X8X32_MK4_4X2X16:96");
#endif
#undef cb
}
  1260. #if MEGDNN_AARCH64
//! im2col symmetric-quantized int8 on NCHW44 with fusion, 3x3 stride 1
//! (aarch64-only section).
TEST_F(ARM_COMMON_MULTI_THREADS,
       CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM_NCHW44_FUSE) {
    UniformIntRNG rng{-50, 50};
// cb(name): qs8 in, qs32 bias, qs8 out
#define cb(name)                                                          \
    checker_conv_bias(get_nchw44_conv_bias_args({3}, 1), handle(), &rng,  \
                      epsilon, dtype::QuantizedS8(2.5f),                  \
                      dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), \
                      dtype::QuantizedS8(60.25f), name);
    float epsilon = 0.001;
    cb("IM2COLMATMUL:AARCH64_INT8X8X32_MK4_4X4X16:96");
#undef cb
}
  1273. #endif
  1274. #endif
  1275. #endif
//! im2col int8x8x32 (NCHW): stride-1 kernels 2..7 plus 1x1 stride-2 merged,
//! checked per arch/feature against every candidate int8 matmul algo.
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32) {
    using namespace conv_bias;
    std::vector<conv_bias::TestArg> args =
            get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, true, true);
    std::vector<conv_bias::TestArg> args1 =
            get_conv_bias_args({1}, 2, false, true, true);
    args.insert(args.begin(), args1.begin(), args1.end());
// cb(name): run the dual int8/quantized check over args
#define cb(name) checker_conv_bias_mul_int8x8x32(args, handle(), name);
#if MEGDNN_AARCH64
#if __ARM_FEATURE_DOTPROD
    cb("IM2COLMATMUL:AARCH64_INT8X8X32_K8X12X4_DOTPROD");
#else
    cb("IM2COLMATMUL:AARCH64_INT8X8X32_K8X8X8");
    cb("IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16");
#endif
#elif MEGDNN_ARMV7
#if __ARM_FEATURE_DOTPROD
    cb("IM2COLMATMUL:AARCH32_INT8_K6X8X4");
#endif
    cb("IM2COLMATMUL:ARMV7_INT8X8X32_K4X8X8");
#endif
#if MEGDNN_ARMV7
    cb("IM2COLMATMUL:ARMV7_INT8X8X32_K4X2X16");
#endif
#undef cb
}
  1302. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COL_S1_MK4_PACK_F32) {
  1303. using namespace conv_bias;
  1304. std::vector<conv_bias::TestArg> args =
  1305. get_nchw44_conv_bias_args({2, 4, 7}, 1);
  1306. #if MEGDNN_AARCH64
  1307. check_conv_bias(args, handle(), "IM2COLMATMUL:AARCH64_F32_MK4_K8X12X1");
  1308. #elif MEGDNN_ARMV7
  1309. check_conv_bias(args, handle(), "IM2COLMATMUL:ARMV7_F32_MK4_PACK_4X12");
  1310. #endif
  1311. }
  1312. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COL_S2_MK4_PACK_F32) {
  1313. using namespace conv_bias;
  1314. std::vector<conv_bias::TestArg> args =
  1315. get_nchw44_conv_bias_args({3, 5, 6}, 2);
  1316. #if MEGDNN_AARCH64
  1317. check_conv_bias(args, handle(), "IM2COLMATMUL:AARCH64_F32_MK4_K8X12X1");
  1318. #elif MEGDNN_ARMV7
  1319. check_conv_bias(args, handle(), "IM2COLMATMUL:ARMV7_F32_MK4_PACK_4X12");
  1320. #endif
  1321. }
  1322. /***************************** Conv1x1 Algo Test ***********************/
  1323. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_F32) {
  1324. using namespace conv_bias;
  1325. std::vector<conv_bias::TestArg> args = get_conv_bias_1x1_args(false, false);
  1326. #if MEGDNN_AARCH64
  1327. check_conv_bias(args, handle(), "CONV1x1:AARCH64_F32K8X12X1:24");
  1328. #elif MEGDNN_ARMV7
  1329. check_conv_bias(args, handle(), "CONV1x1:ARMV7_F32:48");
  1330. #endif
  1331. }
  1332. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_MK4_PACK_F32) {
  1333. using namespace conv_bias;
  1334. std::vector<conv_bias::TestArg> args =
  1335. get_nchw44_conv_bias_args({1}, 1, true, false, false);
  1336. #if MEGDNN_AARCH64
  1337. check_conv_bias(args, handle(), "CONV1x1:AARCH64_F32_MK4_K8X12X1:24");
  1338. #elif MEGDNN_ARMV7
  1339. check_conv_bias(args, handle(), "CONV1x1:ARMV7_F32_MK4_PACK_4X12:24");
  1340. #endif
  1341. }
  1342. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_MK4_NO_PACK_F32) {
  1343. using namespace conv_bias;
  1344. std::vector<conv_bias::TestArg> args =
  1345. get_nchw44_conv_bias_args({1}, 1, true, false, false);
  1346. std::vector<conv_bias::TestArg> args_of_4;
  1347. for (auto&& arg : args) {
  1348. if (arg.src.shape[2] * arg.src.shape[3] % 4 == 0) {
  1349. args_of_4.push_back(arg);
  1350. }
  1351. }
  1352. #if MEGDNN_AARCH64
  1353. check_conv_bias(args_of_4, handle(), "CONV1x1:AARCH64_F32_MK4_4x16:24");
  1354. #elif MEGDNN_ARMV7
  1355. check_conv_bias(args_of_4, handle(), "CONV1x1:ARMV7_F32_MK4_4x8:48");
  1356. #endif
  1357. }
  1358. #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  1359. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_F16) {
  1360. using namespace conv_bias;
  1361. std::vector<conv_bias::TestArg> args = get_conv_bias_1x1_args(false, false);
  1362. NormalRNG rng(1);
  1363. #if MEGDNN_AARCH64
  1364. checker_conv_bias(args, handle(), &rng, 0.03, dtype::Float16{},
  1365. dtype::Float16{}, dtype::Float16{}, dtype::Float16{},
  1366. "CONV1x1:AARCH64_F16_K8X24X1:48");
  1367. #elif MEGDNN_ARMV7
  1368. checker_conv_bias(args, handle(), &rng, 0.03, dtype::Float16{},
  1369. dtype::Float16{}, dtype::Float16{}, dtype::Float16{},
  1370. "CONV1x1:AARCH32_F16_K4X16X1:24");
  1371. #endif
  1372. }
  1373. #endif
  1374. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_QUANTIZEDSYM) {
  1375. UniformIntRNG rng{-50, 50};
  1376. float epsilon = 0.001;
  1377. #define cb(name) \
  1378. checker_conv_bias(get_conv_bias_1x1_args(false, false, true, true), \
  1379. handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \
  1380. dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), \
  1381. dtype::QuantizedS8(60.25f), name);
  1382. #if MEGDNN_AARCH64
  1383. #if __ARM_FEATURE_DOTPROD
  1384. cb("CONV1x1:AARCH64_INT8X8X32_K8X12X4_DOTPROD:24");
  1385. #else
  1386. cb("CONV1x1:AARCH64_INT8X8X32_K8X8X8:24");
  1387. cb("CONV1x1:AARCH64_INT8X8X32_K4X4X16:48");
  1388. #endif
  1389. #elif MEGDNN_ARMV7
  1390. epsilon = 1;
  1391. cb("CONV1x1:ARMV7_INT8X8X32_K4X8X8:48");
  1392. #endif
  1393. #undef cb
  1394. }
  1395. #if MEGDNN_AARCH64 || MEGDNN_ARMV7
  1396. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_QUANTIZEDASYM) {
  1397. NormalRNG rng(128.f);
  1398. #define cb(name) \
  1399. checker_conv_bias(get_conv_bias_1x1_args(false, false, true, true), \
  1400. handle(), &rng, epsilon, \
  1401. dtype::Quantized8Asymm(1.2f, (uint8_t)125), \
  1402. dtype::Quantized8Asymm(1.3f, (uint8_t)129), \
  1403. dtype::QuantizedS32(1.2 * 1.3), \
  1404. dtype::Quantized8Asymm(50.3f, (uint8_t)120), name);
  1405. float epsilon = 0.001;
  1406. #if MEGDNN_AARCH64
  1407. #if __ARM_FEATURE_DOTPROD
  1408. cb("CONV1x1:AARCH64_QUINT8_K8X8X4_DOTPROD:48");
  1409. #else
  1410. cb("CONV1x1:AARCH64_QUINT8_K8X8X8:24");
  1411. #endif
  1412. #elif MEGDNN_ARMV7
  1413. epsilon = 1;
  1414. cb("CONV1x1:ARMV7_QUINT8_K4X8X8:48");
  1415. #endif
  1416. #undef cb
  1417. }
  1418. #endif
  1419. #if MEGDNN_AARCH64 || MEGDNN_ARMV7
  1420. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_QUINT8x8x32) {
  1421. UniformIntRNG rng{-50, 50};
  1422. float epsilon = 0.001;
  1423. #define cb(name) \
  1424. checker_conv_bias(get_conv_bias_1x1_args(true, true), handle(), &rng, \
  1425. epsilon, dtype::Quantized8Asymm(1.2f, (uint8_t)125), \
  1426. dtype::Quantized8Asymm(1.3f, (uint8_t)129), \
  1427. dtype::QuantizedS32(1.2 * 1.3), {}, name);
  1428. #if MEGDNN_AARCH64
  1429. #if __ARM_FEATURE_DOTPROD
  1430. cb("CONV1x1:AARCH64_QUINT8_K8X8X4_DOTPROD:24");
  1431. #else
  1432. cb("CONV1x1:AARCH64_QUINT8_K8X8X8:48");
  1433. #endif
  1434. #elif MEGDNN_ARMV7
  1435. #if __ARM_FEATURE_DOTPROD
  1436. cb("CONV1x1:AARCH32_QUINT8_K4X8X4:48");
  1437. #endif
  1438. cb("CONV1x1:ARMV7_QUINT8_K4X8X8:24");
  1439. #endif
  1440. #undef cb
  1441. }
  1442. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_1X1_S1_INT8x8x16) {
  1443. UniformIntRNG rng{-50, 50};
  1444. float epsilon = 0.001;
  1445. #define cb(name) \
  1446. checker_conv_bias(get_conv_bias_1x1_args(true, true), handle(), &rng, \
  1447. epsilon, dtype::Int8{}, dtype::Int8{}, dtype::Int16{}, \
  1448. dtype::Int16{}, name);
  1449. #if MEGDNN_AARCH64
  1450. cb("CONV1x1:AARCH64_INT8X8X16_K8X8X8:24");
  1451. cb("CONV1x1:AARCH64_INT8X8X16_K4X4X16:24");
  1452. #elif MEGDNN_ARMV7
  1453. cb("CONV1x1:ARMV7_INT8X8X16_K4X8X8:24");
  1454. cb("CONV1x1:ARMV7_INT8X8X16_K4X2X16:48");
  1455. #endif
  1456. cb("CONV1x1:ARM_COMMON_INT8X8X16:48");
  1457. #undef cb
  1458. }
  1459. #endif
  1460. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_INT8x8x32) {
  1461. using namespace conv_bias;
  1462. std::vector<conv_bias::TestArg> args = get_conv_bias_1x1_args(true, true);
  1463. #define cb(name) checker_conv_bias_mul_int8x8x32(args, handle(), name);
  1464. #if MEGDNN_AARCH64
  1465. #if __ARM_FEATURE_DOTPROD
  1466. cb("CONV1x1:AARCH64_INT8X8X32_K8X12X4_DOTPROD:48");
  1467. #else
  1468. cb("CONV1x1:AARCH64_INT8X8X32_K8X8X8:24");
  1469. cb("CONV1x1:AARCH64_INT8X8X32_K4X4X16:24");
  1470. #endif
  1471. #elif MEGDNN_ARMV7
  1472. #if __ARM_FEATURE_DOTPROD
  1473. cb("CONV1x1:AARCH32_INT8_K6X8X4:48");
  1474. #endif
  1475. cb("CONV1x1:ARMV7_INT8X8X32_K4X8X8:24");
  1476. #endif
  1477. #if MEGDNN_ARMV7
  1478. cb("CONV1x1:ARMV7_INT8X8X32_K4X2X16:48");
  1479. #endif
  1480. #undef cb
  1481. }
  1482. #ifndef __ARM_FEATURE_DOTPROD
  1483. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_INT8x8x32_MK4) {
  1484. using namespace conv_bias;
  1485. std::vector<conv_bias::TestArg> args =
  1486. get_nchw44_conv_bias_args({1}, 1, true, true, true);
  1487. #define cb(name) checker_conv_bias_mul_int8x8x32(args, handle(), name);
  1488. #if MEGDNN_AARCH64
  1489. cb("CONV1x1:AARCH64_INT8X8X32_MK4_4X4X16:24");
  1490. #elif MEGDNN_ARMV7
  1491. cb("CONV1x1:ARMV7_INT8X8X32_MK4_4X2X16:24");
  1492. #endif
  1493. #undef cb
  1494. UniformIntRNG rng{-50, 50};
  1495. float epsilon = 0.001;
  1496. #define cb(name) \
  1497. checker_conv_bias(get_nchw44_conv_bias_args({1}, 1, true, false, false), \
  1498. handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \
  1499. dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), \
  1500. dtype::QuantizedS8(60.25f), name);
  1501. #if MEGDNN_AARCH64
  1502. cb("CONV1x1:AARCH64_INT8X8X32_MK4_4X4X16:24");
  1503. #elif MEGDNN_ARMV7
  1504. cb("CONV1x1:ARMV7_INT8X8X32_MK4_4X2X16:24");
  1505. #endif
  1506. #undef cb
  1507. }
  1508. #endif
  1509. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台