
conv_bias_multi_thread.cpp 44 kB

  1. /**
  2. * \file dnn/test/arm_common/conv_bias_multi_thread.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #include "megdnn/dtype.h"
  13. #include "test/arm_common/fixture.h"
  14. #include "test/common/benchmarker.h"
  15. #include "test/common/conv_bias.h"
  16. #include "test/arm_common/cpuinfo_help.h"
  17. using namespace megdnn;
  18. using namespace test;
  19. using namespace conv_bias;
  20. std::vector<conv_bias::TestArg> get_int8_quint8_conv_bias_args(
  21. std::vector<size_t> kernel, size_t stride, bool no_pad, bool no_bias,
  22. bool no_nonlinemode) {
  23. using namespace conv_bias;
  24. using Param = param::ConvBias;
  25. using NLMode = param::ConvBias::NonlineMode;
  26. std::vector<TestArg> args;
  27. auto pack = [&](size_t n, size_t oc, size_t ic, size_t w, size_t h,
  28. size_t kernel, size_t stride, NLMode nlmode) {
  29. Param param;
  30. param.stride_h = stride;
  31. param.stride_w = stride;
  32. if (!no_pad) {
  33. param.pad_h = kernel / 2;
  34. param.pad_w = kernel / 2;
  35. } else {
  36. param.pad_h = 0;
  37. param.pad_w = 0;
  38. }
  39. param.nonlineMode = nlmode;
  40. args.emplace_back(param, TensorShape{n, ic, h, w},
  41. TensorShape{oc, ic, kernel, kernel}, TensorShape{});
  42. if (!no_bias) {
  43. args.emplace_back(param, TensorShape{n, ic, h, w},
  44. TensorShape{oc, ic, kernel, kernel},
  45. TensorShape{1, oc, 1, 1});
  46. }
  47. };
  48. std::vector<NLMode> nonlinemode = {NLMode::IDENTITY};
  49. if (!no_nonlinemode) {
  50. nonlinemode.emplace_back(NLMode::RELU);
  51. nonlinemode.emplace_back(NLMode::H_SWISH);
  52. }
  53. for (size_t n : {1, 2}) {
  54. for (auto nlmode : nonlinemode) {
  55. for (size_t ic : {1, 3, 7}) {
  56. for (size_t oc : {1, 3, 7}) {
  57. for (size_t size : {4, 6, 8, 14, 16, 18}) {
  58. for (size_t kern : kernel) {
  59. pack(n, oc, ic, size, size, kern, stride, nlmode);
  60. }
  61. }
  62. }
  63. }
  64. }
  65. }
  66. return args;
  67. }
  68. std::vector<conv_bias::TestArg> get_nchw44_channel_wise_args(
  69. std::vector<size_t> kernel, size_t stride, bool no_bias,
  70. bool no_nonlinemode, bool no_full_bias) {
  71. using namespace conv_bias;
  72. using Param = param::ConvBias;
  73. using NLMode = param::ConvBias::NonlineMode;
  74. std::vector<TestArg> args;
  75. auto pack = [&](size_t n, size_t group, size_t w, size_t h, size_t kernel,
  76. size_t stride, NLMode nlmode, bool pad) {
  77. Param param;
  78. param.stride_h = stride;
  79. param.stride_w = stride;
  80. if (pad) {
  81. param.pad_h = kernel / 2;
  82. param.pad_w = kernel / 2;
  83. } else {
  84. param.pad_h = 0;
  85. param.pad_w = 0;
  86. }
  87. param.nonlineMode = nlmode;
  88. param.format = param::ConvBias::Format::NCHW44;
  89. param.sparse = param::ConvBias::Sparse::GROUP;
  90. args.emplace_back(param, TensorShape{n, group, h, w, 4},
  91. TensorShape{group, 1, 1, kernel, kernel, 4},
  92. TensorShape{});
  93. if (!no_bias) {
  94. args.emplace_back(param, TensorShape{n, group, h, w, 4},
  95. TensorShape{group, 1, 1, kernel, kernel, 4},
  96. TensorShape{1, group, 1, 1, 4});
  97. }
  98. if (!no_full_bias) {
  99. args.emplace_back(
  100. param, TensorShape{n, group, h, w, 4},
  101. TensorShape{group, 1, 1, kernel, kernel, 4},
  102. TensorShape{n, group,
  103. (h + 2 * param.pad_h - kernel) / stride + 1,
  104. (w + 2 * param.pad_w - kernel) / stride + 1,
  105. 4});
  106. }
  107. };
  108. std::vector<NLMode> nonlinemode = {NLMode::IDENTITY};
  109. if (!no_nonlinemode) {
  110. nonlinemode.emplace_back(NLMode::RELU);
  111. nonlinemode.emplace_back(NLMode::H_SWISH);
  112. }
  113. for (size_t n : {1, 2}) {
  114. for (auto nlmode : nonlinemode) {
  115. for (bool pad : {true}) {
  116. for (size_t group : {1, 2, 4, 7, 128}) {
  117. for (size_t size : {4, 6, 7, 9, 15, 40}) {
  118. for (size_t kern : kernel) {
  119. pack(n, group, size, size, kern, stride, nlmode,
  120. pad);
  121. }
  122. }
  123. }
  124. }
  125. for (bool pad : {false}) {
  126. for (size_t group : {1, 2, 7, 128}) {
  127. for (size_t size : {7, 9, 15, 40}) {
  128. for (size_t kern : kernel) {
  129. pack(n, group, size, size, kern, stride, nlmode,
  130. pad);
  131. }
  132. }
  133. }
  134. }
  135. }
  136. }
  137. return args;
  138. }
  139. void checker_conv_bias_qint8x8x8(std::vector<conv_bias::TestArg> args,
  140. Handle* handle, const char* algo_name) {
  141. Checker<ConvBias> checker(handle);
  142. checker.set_before_exec_callback(
  143. conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name));
  144. #if MEGDNN_ARMV7
  145. checker.set_epsilon(1);
  146. #endif
  147. UniformIntRNG rng{-50, 50};
  148. checker.set_dtype(0, dtype::QuantizedS8(0.41113496f))
  149. .set_dtype(1, dtype::QuantizedS8(0.01887994f))
  150. .set_dtype(2, dtype::QuantizedS32(0.41113496f * 0.01887994f))
  151. .set_dtype(4, dtype::QuantizedS8(0.49550694f))
  152. .set_rng(0, &rng)
  153. .set_rng(1, &rng)
  154. .set_rng(2, &rng);
  155. for (auto&& arg : args) {
  156. checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}});
  157. }
  158. }
  159. void checker_conv_bias_qint8x8x32(std::vector<conv_bias::TestArg> args,
  160. Handle* handle, const char* algo_name) {
  161. Checker<ConvBias> checker(handle);
  162. UniformIntRNG rng{-50, 50};
  163. checker.set_dtype(0, dtype::QuantizedS8(2.5f))
  164. .set_dtype(1, dtype::QuantizedS8(2.5f))
  165. .set_dtype(2, dtype::QuantizedS32(6.25f))
  166. .set_dtype(4, {});
  167. checker.set_before_exec_callback(
  168. conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name));
  169. for (auto&& arg : args) {
  170. checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}});
  171. }
  172. }
  173. void checker_conv_bias_quint8x8x8(std::vector<conv_bias::TestArg> args,
  174. Handle* handle, const char* algo_name) {
  175. Checker<ConvBias> checker(handle);
  176. checker.set_before_exec_callback(
  177. conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name));
  178. UniformIntRNG rng(0, 255);
  179. checker.set_dtype(0, dtype::Quantized8Asymm(0.2f, 100))
  180. .set_dtype(1, dtype::Quantized8Asymm(0.2f, 120))
  181. .set_dtype(2, dtype::QuantizedS32(0.04f))
  182. .set_dtype(4, dtype::Quantized8Asymm(1.4f, 110))
  183. .set_rng(0, &rng)
  184. .set_rng(1, &rng)
  185. .set_rng(2, &rng);
  186. for (auto&& arg : args) {
  187. checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}});
  188. }
  189. }
  190. void checker_conv_bias_quint8x8x32(std::vector<conv_bias::TestArg> args,
  191. Handle* handle, const char* algo_name) {
  192. Checker<ConvBias> checker(handle);
  193. checker.set_before_exec_callback(
  194. conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name));
  195. NormalRNG rng(128.f);
  196. checker.set_rng(0, &rng).set_rng(1, &rng);
  197. checker.set_dtype(0, dtype::Quantized8Asymm(1.2f, (uint8_t)127))
  198. .set_dtype(1, dtype::Quantized8Asymm(1.3f, (uint8_t)129))
  199. .set_dtype(2, dtype::QuantizedS32(1.2 * 1.3))
  200. .set_dtype(4, {});
  201. for (auto&& arg : args) {
  202. checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}});
  203. }
  204. }
  205. void checker_conv_bias_int8x8x32_multi(std::vector<conv_bias::TestArg> args,
  206. Handle* handle, const char* algo_name) {
  207. Checker<ConvBias> checker(handle);
  208. checker.set_before_exec_callback(
  209. conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name));
  210. checker.set_dtype(0, dtype::Int8());
  211. checker.set_dtype(1, dtype::Int8());
  212. checker.set_dtype(2, dtype::Int32());
  213. checker.set_dtype(4, dtype::Int32());
  214. for (auto&& arg : args) {
  215. checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}});
  216. }
  217. }
  218. /**********************************F32 direct************************/
  219. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32) {
  220. check_conv_bias(
  221. get_conv_bias_args({1, 2, 3, 4, 5, 6, 7}, 1, false, false, false),
  222. handle(), "F32DIRECT");
  223. }
  224. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S1_K7) {
  225. //! k=7 s=1
  226. check_conv_bias(get_nchw44_conv_bias_args({7}, ONLY_IDENTITY_NLMODE,
  227. BR_AND_NO_BIASMODE, 1),
  228. handle(), "F32_CONV_NCHW44_DIRECT");
  229. }
  230. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S1_K2K3) {
  231. check_conv_bias(
  232. get_nchw44_conv_bias_args({2, 3}, FULL_NLMODE, ONLY_BR_BIASMODE, 1),
  233. handle(), "F32_CONV_NCHW44_DIRECT");
  234. }
  235. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S1_K5) {
  236. check_conv_bias(
  237. get_nchw44_conv_bias_args({5}, FULL_NLMODE, ONLY_BR_BIASMODE, 1),
  238. handle(), "F32_CONV_NCHW44_DIRECT");
  239. }
  240. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S2) {
  241. check_conv_bias(get_nchw44_conv_bias_args({2, 3, 5, 7}, FULL_NLMODE,
  242. ONLY_BR_BIASMODE, 2),
  243. handle(), "F32_CONV_NCHW44_DIRECT");
  244. }
  245. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR1) {
  246. check_conv_bias(get_conv_bias_args({2, 3, 5, 7}, 1, false, false, false),
  247. handle(), "F32STRD1");
  248. }
  249. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR2) {
  250. check_conv_bias(get_conv_bias_args({2, 3, 5, 7}, 2, false, false, false),
  251. handle(), "F32STRD2");
  252. }
  253. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_NCHW_NCHW44_F32_S2) {
  254. check_conv_bias(
  255. get_nchw44_conv_bias_args({2, 3, 5, 7}, ONLY_IDENTITY_NLMODE,
  256. ONLY_BR_BIASMODE, 2, false, true),
  257. handle(), "F32_CONV_NCHW_NCHW44");
  258. }
  259. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_NCHW_NCHW44_F32_S1) {
  260. check_conv_bias(
  261. get_nchw44_conv_bias_args({2, 3, 5, 7}, ONLY_IDENTITY_NLMODE,
  262. ONLY_BR_BIASMODE, 1, false, true),
  263. handle(), "F32_CONV_NCHW_NCHW44");
  264. }
  265. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CHANNEL_WISE_STRIDE1_FP32_NCHW44_1) {
  266. check_conv_bias(
  267. get_nchw44_channel_wise_args({2, 3}, 1, false, false, false),
  268. handle(), "F32_CHANNEL_WISE_NCHW44");
  269. }
  270. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CHANNEL_WISE_STRIDE1_FP32_NCHW44_2) {
  271. check_conv_bias(get_nchw44_channel_wise_args({5}, 1, false, false, false),
  272. handle(), "F32_CHANNEL_WISE_NCHW44");
  273. }
  274. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CHANNEL_WISE_STRIDE2_FP32_NCHW44) {
  275. check_conv_bias(
  276. get_nchw44_channel_wise_args({2, 3, 5}, 2, false, false, false),
  277. handle(), "F32_CHANNEL_WISE_NCHW44");
  278. }
  279. /**********************************F16 direct************************/
  280. #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  281. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16) {
  282. NormalRNG rng(1);
  283. checker_conv_bias_f16(
  284. get_conv_bias_args({1, 2, 3, 4, 5, 6, 7}, 1, false, false, false),
  285. handle(), rng, "F16DIRECT", 0.03);
  286. }
  287. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16_STR1) {
  288. NormalRNG rng(1);
  289. checker_conv_bias_f16(get_conv_bias_args({2, 3, 5}, 1, false, false, false),
  290. handle(), rng, "F16STRD1", 0.03);
  291. }
  292. #endif
  293. /**********************************algo 8816 direct************************/
  294. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT16_DIRECT) {
  295. checker_conv_bias_int8x8x16(
  296. get_conv_bias_args({2, 3, 5}, 1, false, true, true), handle(),
  297. "I8816DIRECT");
  298. }
  299. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT16_STRIDE2) {
  300. checker_conv_bias_int8x8x16(
  301. get_conv_bias_args({2, 3, 5}, 2, false, true, true), handle(),
  302. "I8816STRD2");
  303. }
  304. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT16_NCHW_NCHW44_S2) {
  305. checker_conv_bias_int8x8x16(
  306. get_nchw44_conv_bias_args({2, 3, 5, 7}, ONLY_IDENTITY_NLMODE,
  307. ONLY_NO_BIASMODE, 2, false, true),
  308. handle(), "I8816_CONV_NCHW_NCHW44");
  309. }
  310. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT16_NCHW_NCHW44_S1) {
  311. checker_conv_bias_int8x8x16(
  312. get_nchw44_conv_bias_args({2, 3, 5, 7}, ONLY_IDENTITY_NLMODE,
  313. ONLY_NO_BIASMODE, 1, false, true),
  314. handle(), "I8816_CONV_NCHW_NCHW44");
  315. }
  316. /**********************************algo 8-8-32 direct************************/
  317. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT32_STRIDE1) {
  318. checker_conv_bias_int8x8x32_multi(
  319. get_conv_bias_args({2, 3, 5, 7}, 1, false, true, true), handle(),
  320. "S8STRD1");
  321. }
  322. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT32_STRIDE2) {
  323. checker_conv_bias_int8x8x32_multi(
  324. get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(),
  325. "S8STRD2");
  326. }
  327. TEST_F(ARM_COMMON_MULTI_THREADS,
  328. CONV_BIAS_INT8_INT8_INT32_CHANNEL_WISE_DIRECT1_NCHW44) {
  329. checker_conv_bias_int8x8x32_multi(
  330. get_nchw44_channel_wise_args({2, 3, 5}, 1, false, true, true),
  331. handle(), "S8_CHAN_WISE_STRD1_NCHW44");
  332. }
  333. TEST_F(ARM_COMMON_MULTI_THREADS,
  334. CONV_BIAS_INT8_INT8_INT32_CHANNEL_WISE_DIRECT2_NCHW44) {
  335. checker_conv_bias_int8x8x32_multi(
  336. get_nchw44_channel_wise_args({2, 3, 5}, 2, false, true, true),
  337. handle(), "S8_CHAN_WISE_STRD2_NCHW44");
  338. }
  339. TEST_F(ARM_COMMON, CONV_BIAS_INT8_INT8_INT16_CHANNEL_WISE_DIRECT1_NCHW44) {
  340. Checker<ConvBias> checker(handle());
  341. checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>(
  342. "S8x8x16_CHAN_WISE_STRD1_STRD2_NCHW44"));
  343. checker.set_dtype(0, dtype::Int8());
  344. checker.set_dtype(1, dtype::Int8());
  345. checker.set_dtype(2, dtype::Int16());
  346. checker.set_dtype(4, dtype::Int16());
  347. auto args = get_nchw44_channel_wise_args({2, 3, 5}, 1, false, true, true);
  348. for (auto&& arg : args) {
  349. checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}});
  350. }
  351. }
  352. TEST_F(ARM_COMMON_MULTI_THREADS,
  353. CONV_BIAS_INT8_INT8_INT16_CHANNEL_WISE_DIRECT2_NCHW44) {
  354. Checker<ConvBias> checker(handle());
  355. checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>(
  356. "S8x8x16_CHAN_WISE_STRD1_STRD2_NCHW44"));
  357. checker.set_dtype(0, dtype::Int8());
  358. checker.set_dtype(1, dtype::Int8());
  359. checker.set_dtype(2, dtype::Int16());
  360. checker.set_dtype(4, dtype::Int16());
  361. auto args = get_nchw44_channel_wise_args({2, 3, 5}, 2, false, true, true);
  362. for (auto&& arg : args) {
  363. checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}});
  364. }
  365. }
  366. /********************************qint8 direct******************************/
  367. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1) {
  368. checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args(
  369. {2, 3, 5, 7}, 1, false, false, false),
  370. handle(), "S8STRD1");
  371. }
  372. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2) {
  373. checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args(
  374. {2, 3, 5, 7}, 2, false, false, false),
  375. handle(), "S8STRD2");
  376. }
  377. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_NCHW44) {
  378. checker_conv_bias_qint8x8x8(
  379. get_nchw44_conv_bias_args({2, 3, 5, 7}, QUAN_NLMODE,
  380. ONLY_BR_BIASMODE, 1),
  381. handle(), "S8_NCHW44_DIRECT");
  382. }
  383. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_NCHW44_8816) {
  384. checker_conv_bias_int8x8x16(
  385. get_nchw44_conv_bias_args({2, 3, 5, 7}, ONLY_IDENTITY_NLMODE,
  386. ONLY_BR_BIASMODE, 1),
  387. handle(), "S8x8x16_NCHW44_DIRECT");
  388. }
  389. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_NCHW44_8816) {
  390. checker_conv_bias_int8x8x16(
  391. get_nchw44_conv_bias_args({2, 3, 5, 7}, ONLY_IDENTITY_NLMODE,
  392. ONLY_BR_BIASMODE, 2),
  393. handle(), "S8x8x16_NCHW44_DIRECT");
  394. }
  395. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_NCHW44_8832) {
  396. checker_conv_bias_qint8x8x32(
  397. get_nchw44_conv_bias_args({2, 3, 5, 7}, ONLY_IDENTITY_NLMODE,
  398. ONLY_BR_BIASMODE, 1),
  399. handle(), "S8_NCHW44_DIRECT");
  400. }
  401. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_NCHW44_8832) {
  402. checker_conv_bias_qint8x8x32(
  403. get_nchw44_conv_bias_args({2, 3, 5, 7}, ONLY_IDENTITY_NLMODE,
  404. ONLY_NO_BIASMODE, 2),
  405. handle(), "S8_NCHW44_DIRECT");
  406. }
  407. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_NCHW44) {
  408. checker_conv_bias_qint8x8x8(
  409. get_nchw44_conv_bias_args({2, 3, 5, 7}, QUAN_NLMODE,
  410. BR_AND_NO_BIASMODE, 2),
  411. handle(), "S8_NCHW44_DIRECT");
  412. }
  413. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QS8_CHANNEL_WISE_DIRECT1_NCHW44) {
  414. checker_conv_bias_qint8x8x8(
  415. get_nchw44_conv_bias_args({2, 3, 5, 7}, ONLY_IDENTITY_NLMODE,
  416. BR_AND_NO_BIASMODE, 1),
  417. handle(), "S8_NCHW44_DIRECT");
  418. checker_conv_bias_qint8x8x8(
  419. get_nchw44_channel_wise_args({2, 3, 5}, 1, false, false, true),
  420. handle(), "S8_CHAN_WISE_STRD1_NCHW44");
  421. }
  422. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QS8_CHANNEL_WISE_DIRECT2_NCHW44) {
  423. checker_conv_bias_qint8x8x8(
  424. get_nchw44_channel_wise_args({2, 3, 5}, 2, false, false, true),
  425. handle(), "S8_CHAN_WISE_STRD2_NCHW44");
  426. }
  427. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_NCHW_NCHW44_S1) {
  428. checker_conv_bias_qint8x8x8(
  429. get_nchw44_conv_bias_args({2, 3, 5, 7}, QUAN_NLMODE,
  430. BR_AND_NO_BIASMODE, 1, false, true),
  431. handle(), "S8_CONV_NCHW_NCHW44");
  432. }
  433. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_NCHW_NCHW44_S2) {
  434. checker_conv_bias_qint8x8x8(
  435. get_nchw44_conv_bias_args({2, 3, 5, 7}, QUAN_NLMODE,
  436. BR_AND_NO_BIASMODE, 2, false, true),
  437. handle(), "S8_CONV_NCHW_NCHW44");
  438. }
  439. /*****************************quint8 direct****************************/
  440. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE1) {
  441. checker_conv_bias_quint8x8x8(get_int8_quint8_conv_bias_args(
  442. {2, 3, 5, 7}, 1, false, false, false),
  443. handle(), "QU8STRD1");
  444. }
  445. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE2) {
  446. checker_conv_bias_quint8x8x8(get_int8_quint8_conv_bias_args(
  447. {2, 3, 5, 7}, 2, false, false, false),
  448. handle(), "QU8STRD2");
  449. }
  450. /****************************dot qint8 direct*************************/
  451. #if MGB_ENABLE_DOT
  452. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_NCHW_NCHW44) {
  453. auto args = get_nchw44_conv_bias_args({2, 3, 5, 7}, QUAN_NLMODE,
  454. BR_AND_NO_BIASMODE, 2, false, true);
  455. for (auto&& arg : args) {
  456. arg.param.format = param::ConvBias::Format::NCHW44_DOT;
  457. }
  458. checker_conv_bias_qint8x8x8(args, handle(), "ARMDOTS8_NCHW_NCHW44");
  459. args = get_nchw44_conv_bias_args({2, 3, 5, 7}, QUAN_NLMODE,
  460. BR_AND_NO_BIASMODE, 1, false, true);
  461. for (auto&& arg : args) {
  462. arg.param.format = param::ConvBias::Format::NCHW44_DOT;
  463. }
  464. checker_conv_bias_qint8x8x8(args, handle(), "ARMDOTS8_NCHW_NCHW44");
  465. }
  466. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_WITHDOTPROD) {
  467. checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args(
  468. {2, 3, 5, 7}, 1, false, false, false),
  469. handle(), "ARMDOTS8STRD1");
  470. }
  471. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_WITHDOTPROD) {
  472. checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args(
  473. {2, 3, 5, 7}, 2, false, false, false),
  474. handle(), "ARMDOTS8STRD2");
  475. }
  476. /****************************dot 8-8-32 direct*************************/
  477. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_I8832STRD1_WITHDOT) {
  478. checker_conv_bias_qint8x8x32(
  479. get_conv_bias_args({2, 3, 5, 7}, 1, false, true, true), handle(),
  480. "ARMDOTS8STRD1");
  481. }
  482. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_I8832STRD2_WITHDOT) {
  483. checker_conv_bias_qint8x8x32(
  484. get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(),
  485. "ARMDOTS8STRD2");
  486. }
  487. /******************************dot quint8*****************************/
  488. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE1_WITHDOTPROD) {
  489. checker_conv_bias_quint8x8x8(get_int8_quint8_conv_bias_args(
  490. {2, 3, 5, 7}, 1, false, false, false),
  491. handle(), "ARMDOTU8STRD1");
  492. }
  493. //! TODO: this test does not cover kernel size=3; adding it currently causes a
  494. //! bus error on armv7
  495. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE2_WITHDOTPROD) {
  496. checker_conv_bias_quint8x8x8(
  497. get_int8_quint8_conv_bias_args({2, 5, 7}, 2, false, false, false),
  498. handle(), "ARMDOTU8STRD2");
  499. }
  500. /******************************dot quint8x8x32***********************/
  501. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_QUINT8_DIRECT_STRIDE1) {
  502. checker_conv_bias_quint8x8x32(
  503. get_conv_bias_args({2, 3, 5, 7}, 1, false, true, true), handle(),
  504. "ARMDOTU8STRD1");
  505. }
  506. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_QUINT8_DIRECT_STRIDE2) {
  507. checker_conv_bias_quint8x8x32(
  508. get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(),
  509. "ARMDOTU8STRD2");
  510. }
  511. /******************************dot int8x8x8 nchw44 ***********************/
  512. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_DIRECT_DOT_NCHW44_S1_Q8x8x8) {
  513. using namespace conv_bias;
  514. std::vector<TestArg> args = get_nchw44_conv_bias_args(
  515. {2, 3, 5, 7}, QUAN_NLMODE, ONLY_BR_BIASMODE, 1);
  516. for (auto&& arg : args)
  517. arg.param.format = param::ConvBias::Format::NCHW44_DOT;
  518. checker_conv_bias_qint8x8x8(args, handle(), "ARMDOTS8DIRECT_NCHW44");
  519. }
  520. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_DIRECT_DOT_NCHW44_S1_Q8x8x32) {
  521. using namespace conv_bias;
  522. std::vector<TestArg> args = get_nchw44_conv_bias_args(
  523. {2, 3, 5, 7}, ONLY_IDENTITY_NLMODE, BR_AND_NO_BIASMODE, 1);
  524. for (auto&& arg : args)
  525. arg.param.format = param::ConvBias::Format::NCHW44_DOT;
  526. checker_conv_bias_qint8x8x32(args, handle(), "ARMDOTS8DIRECT_NCHW44");
  527. }
  528. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_DIRECT_DOT_NCHW44_S1_8x8x32) {
  529. using namespace conv_bias;
  530. std::vector<TestArg> args = get_nchw44_conv_bias_args(
  531. {2, 3, 5, 7}, ONLY_IDENTITY_NLMODE, BR_AND_NO_BIASMODE, 1);
  532. for (auto&& arg : args)
  533. arg.param.format = param::ConvBias::Format::NCHW44_DOT;
  534. checker_conv_bias_int8x8x32_multi(args, handle(), "ARMDOTS8DIRECT_NCHW44");
  535. }
  536. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_DIRECT_DOT_NCHW44_S2_Q8x8x8) {
  537. using namespace conv_bias;
  538. //! test qint8x8x8
  539. std::vector<TestArg> args = get_nchw44_conv_bias_args(
  540. {2, 3, 5, 7}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 2);
  541. for (auto&& arg : args)
  542. arg.param.format = param::ConvBias::Format::NCHW44_DOT;
  543. checker_conv_bias_qint8x8x8(args, handle(), "ARMDOTS8DIRECT_NCHW44");
  544. }
  545. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_DIRECT_DOT_NCHW44_S2_Q8x8x32) {
  546. using namespace conv_bias;
  547. //! test qint8x8x32
  548. std::vector<TestArg> args = get_nchw44_conv_bias_args(
  549. {2, 3, 5, 7}, ONLY_IDENTITY_NLMODE, BR_AND_NO_BIASMODE, 2);
  550. for (auto&& arg : args)
  551. arg.param.format = param::ConvBias::Format::NCHW44_DOT;
  552. checker_conv_bias_qint8x8x32(args, handle(), "ARMDOTS8DIRECT_NCHW44");
  553. }
  554. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_DIRECT_DOT_NCHW44_S2_8x8x32) {
  555. using namespace conv_bias;
  556. //! test int8x8x32
  557. std::vector<TestArg> args = get_nchw44_conv_bias_args(
  558. {2, 3, 5, 7}, ONLY_IDENTITY_NLMODE, BR_AND_NO_BIASMODE, 2);
  559. for (auto&& arg : args)
  560. arg.param.format = param::ConvBias::Format::NCHW44_DOT;
  561. checker_conv_bias_int8x8x32_multi(args, handle(), "ARMDOTS8DIRECT_NCHW44");
  562. }
  563. #endif
  564. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD) {
  565. using namespace conv_bias;
  566. std::vector<TestArg> args = get_winograd_args(3);
  567. Checker<ConvBiasForward> checker(handle());
  568. auto run = [&checker](const std::vector<TestArg>& args, DType A_dtype,
  569. DType B_dtype, DType C_dtype, DType D_dtype,
  570. const float eps) {
  571. for (auto&& arg : args) {
  572. checker.set_dtype(0, A_dtype)
  573. .set_dtype(1, B_dtype)
  574. .set_dtype(2, C_dtype)
  575. .set_dtype(4, D_dtype)
  576. .set_epsilon(eps)
  577. .set_param(arg.param)
  578. .execs({arg.src, arg.filter, arg.bias, {}, {}});
  579. }
  580. };
  581. run(args, dtype::Float32(), dtype::Float32(), dtype::Float32(),
  582. dtype::Float32(), 1e-3f);
  583. #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  584. Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00);
  585. checker.set_rng(0, rng).set_rng(1, rng).set_rng(2, rng);
  586. run(args, dtype::Float16(), dtype::Float16(), dtype::Float16(),
  587. dtype::Float16(), 0.35f);
  588. #endif
  589. }
  590. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F23_4) {
  591. using namespace conv_bias;
  592. std::vector<TestArg> args = get_winograd_mk_packed_args();
  593. Checker<ConvBiasForward> checker(handle());
  594. check_winograd("4:2:32", checker, args, param::MatrixMul::Format::MK4);
  595. }
  596. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F23_4_NCHW44) {
  597. using namespace conv_bias;
  598. std::vector<TestArg> args =
  599. get_nchw44_conv_bias_args({3}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 1);
  600. Checker<ConvBiasForward> checker(handle());
  601. check_winograd("4:2:32", checker, args, param::MatrixMul::Format::MK4,
  602. param::ConvBias::Format::NCHW44);
  603. }
  604. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F63) {
  605. using namespace conv_bias;
  606. std::vector<TestArg> args = get_winograd_args(3);
  607. Checker<ConvBiasForward> checker(handle());
  608. check_winograd("1:6:32", checker, args);
  609. }
  610. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F63_4) {
  611. using namespace conv_bias;
  612. std::vector<TestArg> args = get_winograd_mk_packed_args();
  613. Checker<ConvBiasForward> checker(handle());
  614. check_winograd("4:6:16", checker, args, param::MatrixMul::Format::MK4);
  615. }
  616. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F63_4_NCHW44) {
  617. using namespace conv_bias;
  618. std::vector<TestArg> args =
  619. get_nchw44_conv_bias_args({3}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 1);
  620. Checker<ConvBiasForward> checker(handle());
  621. check_winograd("4:6:16", checker, args, param::MatrixMul::Format::MK4,
  622. param::ConvBias::Format::NCHW44);
  623. }
  624. //! uncomment it when low precision mode is ok
  625. #if 0
  626. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F73_4_NCHW44) {
  627. using namespace conv_bias;
  628. std::vector<TestArg> args =
  629. get_nchw44_conv_bias_args({3}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 1);
  630. Checker<ConvBiasForward> checker(handle());
  631. check_winograd("4:7:16", checker, args, param::MatrixMul::Format::MK4,
  632. param::ConvBias::Format::NCHW44);
  633. }
  634. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F73_4_NCHW44_WEIGHT_PREPROCESS) {
  635. using namespace conv_bias;
  636. std::vector<TestArg> args =
  637. get_nchw44_conv_bias_args({3}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 1);
  638. Checker<ConvBiasForward, OprWeightPreprocessProxy<ConvBiasForward>> checker(
  639. handle());
  640. check_winograd("4:7:16", checker, args, param::MatrixMul::Format::MK4,
  641. param::ConvBias::Format::NCHW44);
  642. }
  643. #endif
  644. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F54) {
  645. using namespace conv_bias;
  646. std::vector<TestArg> args = get_winograd_args(4);
  647. Checker<ConvBiasForward> checker(handle());
  648. check_winograd("1:5:32", checker, args);
  649. }
  650. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F45) {
  651. using namespace conv_bias;
  652. std::vector<TestArg> args = get_winograd_args(5);
  653. Checker<ConvBiasForward> checker(handle());
  654. check_winograd("1:4:32", checker, args);
  655. }
  656. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_MK_PACKED_F32_1) {
  657. using namespace conv_bias;
  658. Checker<ConvBiasForward> checker(handle());
  659. auto run = [&checker](const std::vector<TestArg>& args, DType A_dtype,
  660. DType B_dtype, DType C_dtype, DType D_dtype,
  661. float eps) {
  662. for (auto&& arg : args) {
  663. checker.set_dtype(0, A_dtype)
  664. .set_dtype(1, B_dtype)
  665. .set_dtype(2, C_dtype)
  666. .set_dtype(4, D_dtype)
  667. .set_epsilon(eps)
  668. .set_param(arg.param)
  669. .execs({arg.src, arg.filter, arg.bias, {}, {}});
  670. }
  671. };
  672. std::vector<TestArg> args = get_winograd_mk_packed_args(8);
  673. std::vector<TestArg> args_first_half(args.begin(),
  674. args.begin() + args.size() / 2);
  675. run(args_first_half, dtype::Float32{}, dtype::Float32{}, dtype::Float32{},
  676. dtype::Float32{}, 1e-3f);
  677. }
  678. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_MK_PACKED_F32_2) {
  679. using namespace conv_bias;
  680. Checker<ConvBiasForward> checker(handle());
  681. auto run = [&checker](const std::vector<TestArg>& args, DType A_dtype,
  682. DType B_dtype, DType C_dtype, DType D_dtype,
  683. float eps) {
  684. for (auto&& arg : args) {
  685. checker.set_dtype(0, A_dtype)
  686. .set_dtype(1, B_dtype)
  687. .set_dtype(2, C_dtype)
  688. .set_dtype(4, D_dtype)
  689. .set_epsilon(eps)
  690. .set_param(arg.param)
  691. .execs({arg.src, arg.filter, arg.bias, {}, {}});
  692. }
  693. };
  694. std::vector<TestArg> args = get_winograd_mk_packed_args(8);
  695. std::vector<TestArg> args_second_half(args.begin() + args.size() / 2,
  696. args.end());
  697. run(args_second_half, dtype::Float32{}, dtype::Float32{}, dtype::Float32{},
  698. dtype::Float32{}, 1e-3f);
  699. }
  700. #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  701. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_MK_PACKED_F16) {
  702. using namespace conv_bias;
  703. Checker<ConvBiasForward> checker(handle());
  704. auto run = [&checker](const std::vector<TestArg>& args, DType A_dtype,
  705. DType B_dtype, DType C_dtype, DType D_dtype,
  706. float eps) {
  707. for (auto&& arg : args) {
  708. checker.set_dtype(0, A_dtype)
  709. .set_dtype(1, B_dtype)
  710. .set_dtype(2, C_dtype)
  711. .set_dtype(4, D_dtype)
  712. .set_epsilon(eps)
  713. .set_param(arg.param)
  714. .execs({arg.src, arg.filter, arg.bias, {}, {}});
  715. }
  716. };
  717. std::vector<TestArg> args = get_winograd_mk_packed_args(8);
  718. Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00);
  719. checker.set_rng(0, rng).set_rng(1, rng).set_rng(2, rng);
  720. run(args, dtype::Float16{}, dtype::Float16{}, dtype::Float16{},
  721. dtype::Float16{}, 0.25);
  722. }
  723. #endif
  724. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_MK_PACKED_INT8) {
  725. using namespace conv_bias;
  726. Checker<ConvBiasForward> checker(handle());
  727. auto run = [&checker](const std::vector<TestArg>& args, DType A_dtype,
  728. DType B_dtype, DType C_dtype, DType D_dtype,
  729. float eps) {
  730. for (auto&& arg : args) {
  731. checker.set_dtype(0, A_dtype)
  732. .set_dtype(1, B_dtype)
  733. .set_dtype(2, C_dtype)
  734. .set_dtype(4, D_dtype)
  735. .set_epsilon(eps)
  736. .set_param(arg.param)
  737. .execs({arg.src, arg.filter, arg.bias, {}, {}});
  738. }
  739. };
  740. #if MEGDNN_AARCH64
  741. const char* matmul_name = "AARCH64_INT16X16X32_MK8_8X8";
  742. #else
  743. const char* matmul_name = "ARMV7_INT16X16X32_MK8_4X8";
  744. #endif
  745. checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>(
  746. ssprintf("WINOGRAD:%s:8:2:32", matmul_name).c_str()));
  747. std::vector<TestArg> quantized_args =
  748. get_quantized_winograd_mk_packed_args(8);
  749. UniformIntRNG int_rng{-50, 50};
  750. checker.set_rng(0, &int_rng).set_rng(1, &int_rng).set_rng(2, &int_rng);
  751. run(quantized_args, dtype::QuantizedS8(2.5f), dtype::QuantizedS8(2.5f),
  752. dtype::QuantizedS32(6.25f), dtype::QuantizedS8(60.25f), 1e-3);
  753. }
  754. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8) {
  755. using namespace conv_bias;
  756. Checker<ConvBiasForward> checker(handle());
  757. auto run = [&checker](const std::vector<TestArg>& args,
  758. DType A_dtype,
  759. DType B_dtype, DType C_dtype, DType D_dtype,
  760. float eps) {
  761. for (auto&& arg : args) {
  762. checker.set_dtype(0, A_dtype)
  763. .set_dtype(1, B_dtype)
  764. .set_dtype(2, C_dtype)
  765. .set_dtype(4, D_dtype)
  766. .set_epsilon(eps)
  767. .set_param(arg.param)
  768. .execs({arg.src, arg.filter, arg.bias, {}, {}});
  769. }
  770. };
  771. #if MEGDNN_AARCH64
  772. const char* matmul_name = "AARCH64_INT16X16X32_MK8_8X8";
  773. #else
  774. const char* matmul_name = "ARMV7_INT16X16X32_MK8_4X8";
  775. #endif
  776. checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>(
  777. ssprintf("WINOGRAD_NCHW44:%s:8:2:32", matmul_name).c_str()));
  778. std::vector<TestArg> quantized_args = get_int8_nchw44_args(3, 4);
  779. UniformIntRNG int_rng{-50, 50};
  780. checker.set_rng(0, &int_rng).set_rng(1, &int_rng).set_rng(2, &int_rng);
  781. run(quantized_args, dtype::QuantizedS8(2.5f),
  782. dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f),
  783. dtype::QuantizedS8(60.25f), 1e-3);
  784. }
  785. TEST_F(ARM_COMMON_MULTI_THREADS,
  786. CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8_COMP_F32_GROUPMODE) {
  787. using namespace conv_bias;
  788. Checker<ConvBiasForward> checker(handle());
  789. auto run = [&checker](const std::vector<TestArg>& args, DType A_dtype,
  790. DType B_dtype, DType C_dtype, DType D_dtype,
  791. float eps) {
  792. for (auto&& arg : args) {
  793. checker.set_dtype(0, A_dtype)
  794. .set_dtype(1, B_dtype)
  795. .set_dtype(2, C_dtype)
  796. .set_dtype(4, D_dtype)
  797. .set_epsilon(eps)
  798. .set_param(arg.param)
  799. .execs({arg.src, arg.filter, arg.bias, {}, {}});
  800. }
  801. };
  802. float epsilon = 0.001;
  803. #if MEGDNN_AARCH64
  804. const char* matmul_name = "AARCH64_F32_MK4_4x16";
  805. #else
  806. const char* matmul_name = "ARMV7_F32_MK4_4x8";
  807. #endif
  808. checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>(
  809. ssprintf("WINOGRAD_NCHW44:%s:4:2:32", matmul_name).c_str()));
  810. std::vector<TestArg> quantized_args =
  811. get_int8_nchw44_args(3, 4, true, true);
  812. UniformIntRNG int_rng{-50, 50};
  813. checker.set_rng(0, &int_rng).set_rng(1, &int_rng).set_rng(2, &int_rng);
  814. run(quantized_args, dtype::QuantizedS8(0.41113496f),
  815. dtype::QuantizedS8(0.01887994f),
  816. dtype::QuantizedS32(0.41113496f * 0.01887994f),
  817. dtype::QuantizedS8(0.49550694f), epsilon);
  818. }
  819. TEST_F(ARM_COMMON_MULTI_THREADS,
  820. CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8_GROUPMODE) {
  821. using namespace conv_bias;
  822. Checker<ConvBiasForward> checker(handle());
  823. auto run = [&checker](const std::vector<TestArg>& args, DType A_dtype,
  824. DType B_dtype, DType C_dtype, DType D_dtype,
  825. float eps) {
  826. for (auto&& arg : args) {
  827. checker.set_dtype(0, A_dtype)
  828. .set_dtype(1, B_dtype)
  829. .set_dtype(2, C_dtype)
  830. .set_dtype(4, D_dtype)
  831. .set_epsilon(eps)
  832. .set_param(arg.param)
  833. .execs({arg.src, arg.filter, arg.bias, {}, {}});
  834. }
  835. };
  836. #if MEGDNN_AARCH64
  837. const char* matmul_name = "AARCH64_INT16X16X32_MK8_8X8";
  838. #else
  839. const char* matmul_name = "ARMV7_INT16X16X32_MK8_4X8";
  840. #endif
  841. checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>(
  842. ssprintf("WINOGRAD_NCHW44:%s:8:2:32", matmul_name).c_str()));
  843. std::vector<TestArg> quantized_args =
  844. get_int8_nchw44_args(3, 4, false, true);
  845. UniformIntRNG int_rng{-50, 50};
  846. checker.set_rng(0, &int_rng).set_rng(1, &int_rng).set_rng(2, &int_rng);
  847. run(quantized_args, dtype::QuantizedS8(2.5f), dtype::QuantizedS8(2.5f),
  848. dtype::QuantizedS32(6.25f), dtype::QuantizedS8(60.25f), 1e-3);
  849. }
  850. TEST_F(ARM_COMMON_MULTI_THREADS,
  851. CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8_COMP_F32) {
  852. using namespace conv_bias;
  853. Checker<ConvBiasForward> checker(handle());
  854. auto run = [&checker](const std::vector<TestArg>& args, DType A_dtype,
  855. DType B_dtype, DType C_dtype, DType D_dtype,
  856. float eps) {
  857. for (auto&& arg : args) {
  858. checker.set_dtype(0, A_dtype)
  859. .set_dtype(1, B_dtype)
  860. .set_dtype(2, C_dtype)
  861. .set_dtype(4, D_dtype)
  862. .set_epsilon(eps)
  863. .set_param(arg.param)
  864. .execs({arg.src, arg.filter, arg.bias, {}, {}});
  865. }
  866. };
  867. float epsilon = 0.001;
  868. #if MEGDNN_AARCH64
  869. const char* matmul_name = "AARCH64_F32_MK4_4x16";
  870. #else
  871. const char* matmul_name = "ARMV7_F32_MK4_4x8";
  872. #endif
  873. checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>(
  874. ssprintf("WINOGRAD_NCHW44:%s:4:2:32", matmul_name).c_str()));
  875. std::vector<TestArg> quantized_args = get_int8_nchw44_args(3, 4, true);
  876. UniformIntRNG int_rng{-50, 50};
  877. checker.set_rng(0, &int_rng).set_rng(1, &int_rng).set_rng(2, &int_rng);
  878. run(quantized_args, dtype::QuantizedS8(0.41113496f),
  879. dtype::QuantizedS8(0.01887994f),
  880. dtype::QuantizedS32(0.41113496f * 0.01887994f),
  881. dtype::QuantizedS8(0.49550694f), epsilon);
  882. }
  883. #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  884. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_F23) {
  885. using namespace conv_bias;
  886. std::vector<TestArg> args = get_winograd_mk_packed_args();
  887. Checker<ConvBiasForward> checker(handle());
  888. check_winograd_fp16("1:2:32", checker, args, NULL, 0.08);
  889. }
  890. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_F45_1) {
  891. using namespace conv_bias;
  892. std::vector<TestArg> args = get_winograd_args(5);
  893. std::vector<TestArg> args_head_half(args.begin(),
  894. args.begin() + args.size() / 2);
  895. Checker<ConvBiasForward> checker(handle());
  896. //! fp16 range -1.0 ~ 1.0
  897. Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00);
  898. check_winograd_fp16("1:4:32", checker, args_head_half, rng, 0.25);
  899. }
  900. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_F45_2) {
  901. using namespace conv_bias;
  902. std::vector<TestArg> args = get_winograd_args(5);
  903. std::vector<TestArg> args_back_half(args.begin() + args.size() / 2,
  904. args.end());
  905. Checker<ConvBiasForward> checker(handle());
  906. //! fp16 range -1.0 ~ 1.0
  907. Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00);
  908. check_winograd_fp16("1:4:32", checker, args_back_half, rng, 0.25);
  909. }
  910. //! FIXME: This test may fail when run together with `ARM_COMMON.CONV_BIAS_WINOGRAD*`,
  911. //! but it passes when run as a single testcase
  912. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_F63) {
  913. using namespace conv_bias;
  914. std::vector<TestArg> args = get_winograd_args(3);
  915. Checker<ConvBiasForward> checker(handle());
  916. //! fp16 range -1.0 ~ 1.0
  917. Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00);
  918. check_winograd_fp16("1:6:32", checker, args, rng, 0.3);
  919. }
  920. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_8x8_1) {
  921. using namespace conv_bias;
  922. std::vector<TestArg> args = get_winograd_mk_packed_args(8);
  923. std::vector<TestArg> args_head_half(args.begin(),
  924. args.begin() + args.size() / 2);
  925. Checker<ConvBiasForward> checker(handle());
  926. Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00);
  927. check_winograd_fp16("8:2:32", checker, args_head_half, rng, 0.25,
  928. param::MatrixMul::Format::MK8);
  929. }
  930. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_8x8_2) {
  931. using namespace conv_bias;
  932. std::vector<TestArg> args = get_winograd_mk_packed_args(8);
  933. std::vector<TestArg> args_back_half(args.begin() + args.size() / 2,
  934. args.end());
  935. Checker<ConvBiasForward> checker(handle());
  936. Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00);
  937. check_winograd_fp16("8:2:32", checker, args_back_half, rng, 0.25,
  938. param::MatrixMul::Format::MK8);
  939. }
  940. #endif
  941. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_INT8_8X8) {
  942. using namespace conv_bias;
  943. std::vector<TestArg> args = get_quantized_winograd_mk_packed_args(8);
  944. Checker<ConvBiasForward> checker(handle());
  945. UniformIntRNG rng{-50, 50};
  946. checker.set_dtype(0, dtype::QuantizedS8(2.5f))
  947. .set_dtype(1, dtype::QuantizedS8(2.5f))
  948. .set_dtype(2, dtype::QuantizedS32(6.25f))
  949. .set_dtype(4, dtype::QuantizedS8(60.25f))
  950. .set_rng(0, &rng)
  951. .set_rng(1, &rng)
  952. .set_rng(2, &rng);
  953. check_winograd("8:2:32", checker, args, param::MatrixMul::Format::MK8);
  954. }
  955. TEST_F(ARM_COMMON_MULTI_THREADS,
  956. CONV_BIAS_WINOGRAD_INT8_8X8_WEIGHT_PREPROCESS) {
  957. using namespace conv_bias;
  958. std::vector<TestArg> args = get_quantized_winograd_mk_packed_args(8);
  959. Checker<ConvBiasForward, OprWeightPreprocessProxy<ConvBiasForward>> checker(
  960. handle());
  961. UniformIntRNG rng{-50, 50};
  962. checker.set_dtype(0, dtype::QuantizedS8(2.5f))
  963. .set_dtype(1, dtype::QuantizedS8(2.5f))
  964. .set_dtype(2, dtype::QuantizedS32(6.25f))
  965. .set_dtype(4, dtype::QuantizedS8(60.25f))
  966. .set_rng(0, &rng)
  967. .set_rng(1, &rng)
  968. .set_rng(2, &rng);
  969. check_winograd("8:2:32", checker, args, param::MatrixMul::Format::MK8);
  970. }
  971. // clang-format on
  972. // vim: syntax=cpp.doxygen
