// test/arm_common/conv_bias_multi_thread.cpp (46 kB)
// Multi-threaded ARM ConvBias operator tests for MegDNN.
// (Web-export residue — repository-page notice and fused line-number blob — removed.)
  1. #include "megdnn/dtype.h"
  2. #include "test/arm_common/fixture.h"
  3. #include "test/common/benchmarker.h"
  4. #include "test/common/conv_bias.h"
  5. #include "test/arm_common/cpuinfo_help.h"
  6. using namespace megdnn;
  7. using namespace test;
  8. using namespace conv_bias;
  9. std::vector<conv_bias::TestArg> get_int8_quint8_conv_bias_args(
  10. std::vector<size_t> kernel, size_t stride, bool no_pad, bool no_bias,
  11. bool no_nonlinemode) {
  12. using namespace conv_bias;
  13. using Param = param::ConvBias;
  14. using NLMode = param::ConvBias::NonlineMode;
  15. std::vector<TestArg> args;
  16. auto pack = [&](size_t n, size_t oc, size_t ic, size_t w, size_t h, size_t kernel,
  17. size_t stride, NLMode nlmode) {
  18. Param param;
  19. param.stride_h = stride;
  20. param.stride_w = stride;
  21. if (!no_pad) {
  22. param.pad_h = kernel / 2;
  23. param.pad_w = kernel / 2;
  24. } else {
  25. param.pad_h = 0;
  26. param.pad_w = 0;
  27. }
  28. param.nonlineMode = nlmode;
  29. args.emplace_back(
  30. param, TensorShape{n, ic, h, w}, TensorShape{oc, ic, kernel, kernel},
  31. TensorShape{});
  32. if (!no_bias) {
  33. args.emplace_back(
  34. param, TensorShape{n, ic, h, w},
  35. TensorShape{oc, ic, kernel, kernel}, TensorShape{1, oc, 1, 1});
  36. }
  37. };
  38. std::vector<NLMode> nonlinemode = {NLMode::IDENTITY};
  39. if (!no_nonlinemode) {
  40. nonlinemode.emplace_back(NLMode::RELU);
  41. nonlinemode.emplace_back(NLMode::H_SWISH);
  42. }
  43. for (size_t n : {1, 2}) {
  44. for (auto nlmode : nonlinemode) {
  45. for (size_t ic : {1, 3, 7}) {
  46. for (size_t oc : {1, 3, 7}) {
  47. for (size_t size : {4, 6, 8, 14, 16, 18}) {
  48. for (size_t kern : kernel) {
  49. pack(n, oc, ic, size, size, kern, stride, nlmode);
  50. }
  51. }
  52. }
  53. }
  54. }
  55. }
  56. return args;
  57. }
  58. std::vector<conv_bias::TestArg> get_nchw44_channel_wise_args(
  59. std::vector<size_t> kernel, size_t stride, bool no_bias, bool no_nonlinemode,
  60. bool no_full_bias) {
  61. using namespace conv_bias;
  62. using Param = param::ConvBias;
  63. using NLMode = param::ConvBias::NonlineMode;
  64. std::vector<TestArg> args;
  65. auto pack = [&](size_t n, size_t group, size_t w, size_t h, size_t kernel,
  66. size_t stride, NLMode nlmode, bool pad) {
  67. Param param;
  68. param.stride_h = stride;
  69. param.stride_w = stride;
  70. if (pad) {
  71. param.pad_h = kernel / 2;
  72. param.pad_w = kernel / 2;
  73. } else {
  74. param.pad_h = 0;
  75. param.pad_w = 0;
  76. }
  77. param.nonlineMode = nlmode;
  78. param.format = param::ConvBias::Format::NCHW44;
  79. param.sparse = param::ConvBias::Sparse::GROUP;
  80. args.emplace_back(
  81. param, TensorShape{n, group, h, w, 4},
  82. TensorShape{group, 1, 1, kernel, kernel, 4}, TensorShape{});
  83. if (!no_bias) {
  84. args.emplace_back(
  85. param, TensorShape{n, group, h, w, 4},
  86. TensorShape{group, 1, 1, kernel, kernel, 4},
  87. TensorShape{1, group, 1, 1, 4});
  88. }
  89. if (!no_full_bias) {
  90. args.emplace_back(
  91. param, TensorShape{n, group, h, w, 4},
  92. TensorShape{group, 1, 1, kernel, kernel, 4},
  93. TensorShape{
  94. n, group, (h + 2 * param.pad_w - kernel) / stride + 1,
  95. (w + 2 * param.pad_w - kernel) / stride + 1, 4});
  96. }
  97. };
  98. std::vector<NLMode> nonlinemode = {NLMode::IDENTITY};
  99. if (!no_nonlinemode) {
  100. nonlinemode.emplace_back(NLMode::RELU);
  101. nonlinemode.emplace_back(NLMode::H_SWISH);
  102. }
  103. for (size_t n : {1, 2}) {
  104. for (auto nlmode : nonlinemode) {
  105. for (bool pad : {true}) {
  106. for (size_t group : {1, 2, 4, 7, 16}) {
  107. for (size_t size : {4, 6, 7, 9, 20}) {
  108. for (size_t kern : kernel) {
  109. pack(n, group, size, size, kern, stride, nlmode, pad);
  110. }
  111. }
  112. }
  113. }
  114. for (bool pad : {false}) {
  115. for (size_t group : {1, 2, 7, 16}) {
  116. for (size_t size : {7, 9, 20}) {
  117. for (size_t kern : kernel) {
  118. pack(n, group, size, size, kern, stride, nlmode, pad);
  119. }
  120. }
  121. }
  122. }
  123. }
  124. }
  125. return args;
  126. }
  127. std::vector<conv_bias::TestArg> get_channel_wise_args(
  128. std::vector<size_t> kernel, size_t stride, bool no_bias, bool no_nonlinemode,
  129. bool no_full_bias, bool support_relu) {
  130. using namespace conv_bias;
  131. using Param = param::ConvBias;
  132. using NLMode = param::ConvBias::NonlineMode;
  133. std::vector<TestArg> args;
  134. auto pack = [&](size_t n, size_t group, size_t w, size_t h, size_t kernel,
  135. size_t stride, NLMode nlmode, bool pad) {
  136. Param param;
  137. param.stride_h = stride;
  138. param.stride_w = stride;
  139. if (pad) {
  140. param.pad_h = kernel / 2;
  141. param.pad_w = kernel / 2;
  142. } else {
  143. param.pad_h = 0;
  144. param.pad_w = 0;
  145. }
  146. param.nonlineMode = nlmode;
  147. param.format = param::ConvBias::Format::NCHW;
  148. param.sparse = param::ConvBias::Sparse::GROUP;
  149. args.emplace_back(
  150. param, TensorShape{n, group, h, w},
  151. TensorShape{group, 1, 1, kernel, kernel}, TensorShape{});
  152. if (!no_bias) {
  153. args.emplace_back(
  154. param, TensorShape{n, group, h, w},
  155. TensorShape{group, 1, 1, kernel, kernel},
  156. TensorShape{1, group, 1, 1});
  157. }
  158. if (!no_full_bias) {
  159. args.emplace_back(
  160. param, TensorShape{n, group, h, w},
  161. TensorShape{group, 1, 1, kernel, kernel},
  162. TensorShape{
  163. n, group, (h + 2 * param.pad_w - kernel) / stride + 1,
  164. (w + 2 * param.pad_w - kernel) / stride + 1});
  165. }
  166. };
  167. std::vector<NLMode> nonlinemode = {NLMode::IDENTITY};
  168. if (!no_nonlinemode) {
  169. nonlinemode.emplace_back(NLMode::RELU);
  170. nonlinemode.emplace_back(NLMode::H_SWISH);
  171. } else if (support_relu) {
  172. nonlinemode.emplace_back(NLMode::RELU);
  173. }
  174. for (size_t n : {1, 2}) {
  175. for (auto nlmode : nonlinemode) {
  176. for (bool pad : {true}) {
  177. for (size_t group : {1, 3, 7}) {
  178. for (size_t size : {4, 6, 7, 9, 16, 20, 32, 55}) {
  179. for (size_t kern : kernel) {
  180. pack(n, group, size, size, kern, stride, nlmode, pad);
  181. }
  182. }
  183. }
  184. }
  185. for (bool pad : {false}) {
  186. for (size_t group : {7}) {
  187. for (size_t size : {37}) {
  188. for (size_t kern : kernel) {
  189. if (size < kern)
  190. continue;
  191. pack(n, group, size, size, kern, stride, nlmode, pad);
  192. }
  193. }
  194. }
  195. }
  196. }
  197. }
  198. return args;
  199. }
  200. std::vector<conv_bias::TestArg> get_nchw88_channel_wise_args(
  201. std::vector<size_t> kernel, size_t stride, bool no_bias, bool no_nonlinemode,
  202. bool no_full_bias) {
  203. using namespace conv_bias;
  204. using Param = param::ConvBias;
  205. using NLMode = param::ConvBias::NonlineMode;
  206. std::vector<TestArg> args;
  207. auto pack = [&](size_t n, size_t group, size_t w, size_t h, size_t kernel,
  208. size_t stride, NLMode nlmode, bool pad) {
  209. Param param;
  210. param.stride_h = stride;
  211. param.stride_w = stride;
  212. if (pad) {
  213. param.pad_h = kernel / 2;
  214. param.pad_w = kernel / 2;
  215. } else {
  216. param.pad_h = 0;
  217. param.pad_w = 0;
  218. }
  219. param.nonlineMode = nlmode;
  220. param.format = param::ConvBias::Format::NCHW88;
  221. param.sparse = param::ConvBias::Sparse::GROUP;
  222. args.emplace_back(
  223. param, TensorShape{n, group, h, w, 8},
  224. TensorShape{group, 1, 1, kernel, kernel, 8}, TensorShape{});
  225. if (!no_bias) {
  226. args.emplace_back(
  227. param, TensorShape{n, group, h, w, 8},
  228. TensorShape{group, 1, 1, kernel, kernel, 8},
  229. TensorShape{1, group, 1, 1, 8});
  230. }
  231. if (!no_full_bias) {
  232. args.emplace_back(
  233. param, TensorShape{n, group, h, w, 8},
  234. TensorShape{group, 1, 1, kernel, kernel, 8},
  235. TensorShape{
  236. n, group, (h + 2 * param.pad_w - kernel) / stride + 1,
  237. (w + 2 * param.pad_w - kernel) / stride + 1, 8});
  238. }
  239. };
  240. std::vector<NLMode> nonlinemode = {NLMode::IDENTITY};
  241. if (!no_nonlinemode) {
  242. nonlinemode.emplace_back(NLMode::RELU);
  243. nonlinemode.emplace_back(NLMode::H_SWISH);
  244. }
  245. for (size_t n : {1, 2}) {
  246. for (auto nlmode : nonlinemode) {
  247. for (bool pad : {true}) {
  248. for (size_t group : {1, 2, 4, 7, 8, 16}) {
  249. for (size_t size : {4, 6, 7, 9, 20}) {
  250. for (size_t kern : kernel) {
  251. pack(n, group, size, size, kern, stride, nlmode, pad);
  252. }
  253. }
  254. }
  255. }
  256. for (bool pad : {false}) {
  257. for (size_t group : {1, 2, 7, 16}) {
  258. for (size_t size : {7, 9, 20}) {
  259. for (size_t kern : kernel) {
  260. pack(n, group, size, size, kern, stride, nlmode, pad);
  261. }
  262. }
  263. }
  264. }
  265. }
  266. }
  267. return args;
  268. }
  269. void checker_conv_bias_qint8x8x8(
  270. std::vector<conv_bias::TestArg> args, Handle* handle, const char* algo_name) {
  271. Checker<ConvBias> checker(handle);
  272. checker.set_before_exec_callback(
  273. conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name));
  274. #if MEGDNN_ARMV7
  275. checker.set_epsilon(1);
  276. #endif
  277. UniformIntRNG rng{-50, 50};
  278. checker.set_dtype(0, dtype::QuantizedS8(0.41113496f))
  279. .set_dtype(1, dtype::QuantizedS8(0.01887994f))
  280. .set_dtype(2, dtype::QuantizedS32(0.41113496f * 0.01887994f))
  281. .set_dtype(4, dtype::QuantizedS8(0.49550694f))
  282. .set_rng(0, &rng)
  283. .set_rng(1, &rng)
  284. .set_rng(2, &rng);
  285. for (auto&& arg : args) {
  286. checker.set_param(arg.param).execs({arg.src, arg.filter, arg.bias, {}, {}});
  287. }
  288. }
  289. void checker_conv_bias_qint8x8x32(
  290. std::vector<conv_bias::TestArg> args, Handle* handle, const char* algo_name) {
  291. Checker<ConvBias> checker(handle);
  292. UniformIntRNG rng{-50, 50};
  293. checker.set_dtype(0, dtype::QuantizedS8(2.5f))
  294. .set_dtype(1, dtype::QuantizedS8(2.5f))
  295. .set_dtype(2, dtype::QuantizedS32(6.25f))
  296. .set_dtype(4, {});
  297. checker.set_before_exec_callback(
  298. conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name));
  299. for (auto&& arg : args) {
  300. checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}});
  301. }
  302. }
  303. void checker_conv_bias_quint8x8x8(
  304. std::vector<conv_bias::TestArg> args, Handle* handle, const char* algo_name) {
  305. Checker<ConvBias> checker(handle);
  306. checker.set_before_exec_callback(
  307. conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name));
  308. UniformIntRNG rng(0, 255);
  309. checker.set_dtype(0, dtype::Quantized8Asymm(0.2f, 100))
  310. .set_dtype(1, dtype::Quantized8Asymm(0.2f, 120))
  311. .set_dtype(2, dtype::QuantizedS32(0.04f))
  312. .set_dtype(4, dtype::Quantized8Asymm(1.4f, 110))
  313. .set_rng(0, &rng)
  314. .set_rng(1, &rng)
  315. .set_rng(2, &rng);
  316. for (auto&& arg : args) {
  317. checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}});
  318. }
  319. }
  320. void checker_conv_bias_quint8x8x32(
  321. std::vector<conv_bias::TestArg> args, Handle* handle, const char* algo_name) {
  322. Checker<ConvBias> checker(handle);
  323. checker.set_before_exec_callback(
  324. conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name));
  325. NormalRNG rng(128.f);
  326. checker.set_rng(0, &rng).set_rng(1, &rng);
  327. checker.set_dtype(0, dtype::Quantized8Asymm(1.2f, (uint8_t)127))
  328. .set_dtype(1, dtype::Quantized8Asymm(1.3f, (uint8_t)129))
  329. .set_dtype(2, dtype::QuantizedS32(1.2 * 1.3))
  330. .set_dtype(4, {});
  331. for (auto&& arg : args) {
  332. checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}});
  333. }
  334. }
  335. void checker_conv_bias_int8x8x32_multi(
  336. std::vector<conv_bias::TestArg> args, Handle* handle, const char* algo_name) {
  337. Checker<ConvBias> checker(handle);
  338. checker.set_before_exec_callback(
  339. conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name));
  340. checker.set_dtype(0, dtype::Int8());
  341. checker.set_dtype(1, dtype::Int8());
  342. checker.set_dtype(2, dtype::Int32());
  343. checker.set_dtype(4, dtype::Int32());
  344. for (auto&& arg : args) {
  345. checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}});
  346. }
  347. }
/**********************************F16 direct************************/
//! fp16 tests only compile when the toolchain advertises fp16 vector
//! arithmetic; every check uses tolerance 0.03.
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
//! direct fp16 conv, stride 1, kernel sizes 1..7
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16) {
    NormalRNG rng(1);
    checker_conv_bias_f16(
            get_conv_bias_args({1, 2, 3, 4, 5, 6, 7}, 1, false, false, false), handle(),
            rng, "F16DIRECT", 0.03);
}
//! stride-1 specialized fp16 kernels (2/3/5)
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16_STR1) {
    NormalRNG rng(1);
    checker_conv_bias_f16(
            get_conv_bias_args({2, 3, 5}, 1, false, false, false), handle(), rng,
            "F16STRD1", 0.03);
}
//! NCHW88 channel-wise fp16, stride 1, small kernels
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CHANNEL_WISE_STRIDE1_FP16_NCHW88_1) {
    NormalRNG rng(1);
    checker_conv_bias_f16(
            get_nchw88_channel_wise_args({2, 3}, 1, false, false, false), handle(), rng,
            "F16_CHANNEL_WISE_NCHW88", 0.03);
}
//! NCHW88 channel-wise fp16, stride 1, kernel 5 (split out to keep cases small)
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CHANNEL_WISE_STRIDE1_FP16_NCHW88_2) {
    NormalRNG rng(1);
    checker_conv_bias_f16(
            get_nchw88_channel_wise_args({5}, 1, false, false, false), handle(), rng,
            "F16_CHANNEL_WISE_NCHW88", 0.03);
}
//! NCHW88 channel-wise fp16, stride 2
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CHANNEL_WISE_STRIDE2_FP16_NCHW88) {
    NormalRNG rng(1);
    checker_conv_bias_f16(
            get_nchw88_channel_wise_args({2, 3, 5}, 2, false, false, false), handle(),
            rng, "F16_CHANNEL_WISE_NCHW88", 0.03);
}
//! NCHW88 direct fp16, stride 1, all nonlinearities and bias modes
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16_NCHW88_S1) {
    NormalRNG rng(1);
    checker_conv_bias_f16(
            get_nchw88_conv_bias_args({1, 2, 3, 5, 7}, FULL_NLMODE, ALL_BIASMODE, 1),
            handle(), rng, "F16_CONV_NCHW88_DIRECT", 0.03);
}
//! NCHW88 direct fp16, stride 2
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16_NCHW88_S2) {
    NormalRNG rng(1);
    checker_conv_bias_f16(
            get_nchw88_conv_bias_args({1, 2, 3, 5, 7}, FULL_NLMODE, ALL_BIASMODE, 2),
            handle(), rng, "F16_CONV_NCHW88_DIRECT", 0.03);
}
#endif
/**********************************algo 8816 direct************************/
//! int8 x int8 -> int16 direct algorithms (no nonlinearity, no bias).
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT16_DIRECT) {
    checker_conv_bias_int8x8x16(
            get_conv_bias_args({2, 3, 5}, 1, false, true, true), handle(),
            "I8816DIRECT");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT16_STRIDE2) {
    checker_conv_bias_int8x8x16(
            get_conv_bias_args({2, 3, 5}, 2, false, true, true), handle(),
            "I8816STRD2");
}
//! NCHW input -> NCHW44 output hybrid layout, stride 2
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT16_NCHW_NCHW44_S2) {
    checker_conv_bias_int8x8x16(
            get_nchw44_conv_bias_args(
                    {2, 3, 5, 7}, ONLY_IDENTITY_NLMODE, ONLY_NO_BIASMODE, 2, false,
                    true),
            handle(), "I8816_CONV_NCHW_NCHW44");
}
//! NCHW input -> NCHW44 output hybrid layout, stride 1
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT16_NCHW_NCHW44_S1) {
    checker_conv_bias_int8x8x16(
            get_nchw44_conv_bias_args(
                    {2, 3, 5, 7}, ONLY_IDENTITY_NLMODE, ONLY_NO_BIASMODE, 1, false,
                    true),
            handle(), "I8816_CONV_NCHW_NCHW44");
}
/**********************************algo 8-8-32 direct************************/
//! int8 x int8 -> int32 direct algorithms.
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT32_STRIDE1) {
    checker_conv_bias_int8x8x32_multi(
            get_conv_bias_args({2, 3, 5, 7}, 1, false, true, true), handle(),
            "S8STRD1");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT32_STRIDE2) {
    checker_conv_bias_int8x8x32_multi(
            get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(),
            "S8STRD2");
}
//! NCHW44 channel-wise, stride 1
TEST_F(ARM_COMMON_MULTI_THREADS,
       CONV_BIAS_INT8_INT8_INT32_CHANNEL_WISE_DIRECT1_NCHW44) {
    checker_conv_bias_int8x8x32_multi(
            get_nchw44_channel_wise_args({2, 3, 5}, 1, false, true, true), handle(),
            "S8_CHAN_WISE_STRD1_NCHW44");
}
//! NCHW44 channel-wise, stride 2
TEST_F(ARM_COMMON_MULTI_THREADS,
       CONV_BIAS_INT8_INT8_INT32_CHANNEL_WISE_DIRECT2_NCHW44) {
    checker_conv_bias_int8x8x32_multi(
            get_nchw44_channel_wise_args({2, 3, 5}, 2, false, true, true), handle(),
            "S8_CHAN_WISE_STRD2_NCHW44");
}
//! int8 x int8 -> int16 channel-wise, stride 1.
//! NOTE(review): this one uses the single-thread ARM_COMMON fixture unlike the
//! rest of the file — confirm that is intentional.
TEST_F(ARM_COMMON, CONV_BIAS_INT8_INT8_INT16_CHANNEL_WISE_DIRECT1_NCHW44) {
    Checker<ConvBias> checker(handle());
    checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>(
            "S8x8x16_CHAN_WISE_STRD1_STRD2_NCHW44"));
    checker.set_dtype(0, dtype::Int8());
    checker.set_dtype(1, dtype::Int8());
    checker.set_dtype(2, dtype::Int16());
    checker.set_dtype(4, dtype::Int16());
    auto args = get_nchw44_channel_wise_args({2, 3, 5}, 1, false, true, true);
    for (auto&& arg : args) {
        checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}});
    }
}
//! int8 x int8 -> int16 channel-wise, stride 2
TEST_F(ARM_COMMON_MULTI_THREADS,
       CONV_BIAS_INT8_INT8_INT16_CHANNEL_WISE_DIRECT2_NCHW44) {
    Checker<ConvBias> checker(handle());
    checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>(
            "S8x8x16_CHAN_WISE_STRD1_STRD2_NCHW44"));
    checker.set_dtype(0, dtype::Int8());
    checker.set_dtype(1, dtype::Int8());
    checker.set_dtype(2, dtype::Int16());
    checker.set_dtype(4, dtype::Int16());
    auto args = get_nchw44_channel_wise_args({2, 3, 5}, 2, false, true, true);
    for (auto&& arg : args) {
        checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}});
    }
}
/********************************qint8 direct******************************/
//! quantized int8 direct algorithms.
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1) {
    checker_conv_bias_qint8x8x8(
            get_int8_quint8_conv_bias_args({2, 3, 5, 7}, 1, false, false, false),
            handle(), "S8STRD1");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2) {
    checker_conv_bias_qint8x8x8(
            get_int8_quint8_conv_bias_args({2, 3, 5, 7}, 2, false, false, false),
            handle(), "S8STRD2");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_NCHW44) {
    checker_conv_bias_qint8x8x8(
            get_nchw44_conv_bias_args({2, 3, 5, 7}, QUAN_NLMODE, ONLY_BR_BIASMODE, 1),
            handle(), "S8_NCHW44_DIRECT");
}
//! the same NCHW44 direct algo with an int16 output dtype
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_NCHW44_8816) {
    checker_conv_bias_int8x8x16(
            get_nchw44_conv_bias_args(
                    {2, 3, 5, 7}, ONLY_IDENTITY_NLMODE, ONLY_BR_BIASMODE, 1),
            handle(), "S8x8x16_NCHW44_DIRECT");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_NCHW44_8816) {
    checker_conv_bias_int8x8x16(
            get_nchw44_conv_bias_args(
                    {2, 3, 5, 7}, ONLY_IDENTITY_NLMODE, ONLY_BR_BIASMODE, 2),
            handle(), "S8x8x16_NCHW44_DIRECT");
}
//! the same NCHW44 direct algo with an int32 output dtype
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_NCHW44_8832) {
    checker_conv_bias_qint8x8x32(
            get_nchw44_conv_bias_args(
                    {2, 3, 5, 7}, ONLY_IDENTITY_NLMODE, ONLY_BR_BIASMODE, 1),
            handle(), "S8_NCHW44_DIRECT");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_NCHW44_8832) {
    checker_conv_bias_qint8x8x32(
            get_nchw44_conv_bias_args(
                    {2, 3, 5, 7}, ONLY_IDENTITY_NLMODE, ONLY_NO_BIASMODE, 2),
            handle(), "S8_NCHW44_DIRECT");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_NCHW44) {
    checker_conv_bias_qint8x8x8(
            get_nchw44_conv_bias_args({2, 3, 5, 7}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 2),
            handle(), "S8_NCHW44_DIRECT");
}
//! channel-wise qint8, stride 1 (also re-runs the dense NCHW44 direct algo)
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QS8_CHANNEL_WISE_DIRECT1_NCHW44) {
    checker_conv_bias_qint8x8x8(
            get_nchw44_conv_bias_args(
                    {2, 3, 5, 7}, ONLY_IDENTITY_NLMODE, BR_AND_NO_BIASMODE, 1),
            handle(), "S8_NCHW44_DIRECT");
    checker_conv_bias_qint8x8x8(
            get_nchw44_channel_wise_args({2, 3, 5}, 1, false, false, true), handle(),
            "S8_CHAN_WISE_STRD1_NCHW44");
}
//! channel-wise qint8, stride 2
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QS8_CHANNEL_WISE_DIRECT2_NCHW44) {
    checker_conv_bias_qint8x8x8(
            get_nchw44_channel_wise_args({2, 3, 5}, 2, false, false, true), handle(),
            "S8_CHAN_WISE_STRD2_NCHW44");
}
//! NCHW input -> NCHW44 output hybrid layout
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_NCHW_NCHW44_S1) {
    checker_conv_bias_qint8x8x8(
            get_nchw44_conv_bias_args(
                    {2, 3, 5, 7}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 1, false, true),
            handle(), "S8_CONV_NCHW_NCHW44");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_NCHW_NCHW44_S2) {
    checker_conv_bias_qint8x8x8(
            get_nchw44_conv_bias_args(
                    {2, 3, 5, 7}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 2, false, true),
            handle(), "S8_CONV_NCHW_NCHW44");
}
//! 1x1 filter special case
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_NCHW_NCHW44_S1_F1) {
    checker_conv_bias_qint8x8x8(
            get_nchw44_conv_bias_args(
                    {1}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 1, false, true),
            handle(), "S8_CONV_NCHW_NCHW44");
}
/*****************************quint8 direct****************************/
//! asymmetric-quantized uint8 direct algorithms.
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE1) {
    checker_conv_bias_quint8x8x8(
            get_int8_quint8_conv_bias_args({2, 3, 5, 7}, 1, false, false, false),
            handle(), "QU8STRD1");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE2) {
    checker_conv_bias_quint8x8x8(
            get_int8_quint8_conv_bias_args({2, 3, 5, 7}, 2, false, false, false),
            handle(), "QU8STRD2");
}
  556. /****************************dot qint8 direct*************************/
  557. #if MGB_ENABLE_DOT
  558. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_DIRECT_LARGE_S1) {
  559. checker_conv_bias_qint8x8x8(
  560. get_channel_wise_args({9, 11}, 1, false, true, true, true), handle(),
  561. "ARMDOTS8_DIRECT_CHANWISE_LARGE");
  562. }
  563. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_DIRECT_LARGE_S2) {
  564. checker_conv_bias_qint8x8x8(
  565. get_channel_wise_args({9, 11}, 2, false, true, true, true), handle(),
  566. "ARMDOTS8_DIRECT_CHANWISE_LARGE");
  567. }
  568. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_IM2COL_LARGE_S1) {
  569. checker_conv_bias_qint8x8x8(
  570. get_channel_wise_args({9}, 1, false, true, true, true), handle(),
  571. "ARMDOTS8_IM2COL_CHANWISE_LARGE");
  572. }
  573. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_IM2COL_LARGE_S2) {
  574. checker_conv_bias_qint8x8x8(
  575. get_channel_wise_args({9}, 2, false, true, true, true), handle(),
  576. "ARMDOTS8_IM2COL_CHANWISE_LARGE");
  577. }
  578. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_NCHW_NCHW44) {
  579. auto args = get_nchw44_conv_bias_args(
  580. {2, 3, 5, 7}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 2, false, true);
  581. for (auto&& arg : args) {
  582. arg.param.format = param::ConvBias::Format::NCHW44_DOT;
  583. }
  584. checker_conv_bias_qint8x8x8(args, handle(), "ARMDOTS8_NCHW_NCHW44");
  585. args = get_nchw44_conv_bias_args(
  586. {2, 3, 5, 7}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 1, false, true);
  587. for (auto&& arg : args) {
  588. arg.param.format = param::ConvBias::Format::NCHW44_DOT;
  589. }
  590. checker_conv_bias_qint8x8x8(args, handle(), "ARMDOTS8_NCHW_NCHW44");
  591. }
  592. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_NCHW_NCHW44_S1_F1) {
  593. auto args = get_nchw44_conv_bias_args(
  594. {1}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 1, false, true);
  595. for (auto&& arg : args) {
  596. arg.param.format = param::ConvBias::Format::NCHW44_DOT;
  597. }
  598. checker_conv_bias_qint8x8x8(args, handle(), "ARMDOTS8_NCHW_NCHW44");
  599. }
  600. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_WITHDOTPROD) {
  601. checker_conv_bias_qint8x8x8(
  602. get_int8_quint8_conv_bias_args({2, 3, 5, 7}, 1, false, false, false),
  603. handle(), "ARMDOTS8STRD1");
  604. }
  605. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_WITHDOTPROD) {
  606. checker_conv_bias_qint8x8x8(
  607. get_int8_quint8_conv_bias_args({2, 3, 5, 7}, 2, false, false, false),
  608. handle(), "ARMDOTS8STRD2");
  609. }
  610. /****************************dot 8-8-32 direct*************************/
  611. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_I8832STRD1_WITHDOT) {
  612. checker_conv_bias_qint8x8x32(
  613. get_conv_bias_args({2, 3, 5, 7}, 1, false, true, true), handle(),
  614. "ARMDOTS8STRD1");
  615. }
  616. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_I8832STRD2_WITHDOT) {
  617. checker_conv_bias_qint8x8x32(
  618. get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(),
  619. "ARMDOTS8STRD2");
  620. }
  621. /******************************dot quint8*****************************/
  622. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE1_WITHDOTPROD) {
  623. checker_conv_bias_quint8x8x8(
  624. get_int8_quint8_conv_bias_args({2, 3, 5, 7}, 1, false, false, false),
  625. handle(), "ARMDOTU8STRD1");
  626. }
  627. //! TODO: this test without test kernel size=3, add it will case buss error now
  628. //! in armv7
  629. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE2_WITHDOTPROD) {
  630. checker_conv_bias_quint8x8x8(
  631. get_int8_quint8_conv_bias_args({2, 5, 7}, 2, false, false, false), handle(),
  632. "ARMDOTU8STRD2");
  633. }
  634. /******************************dot quint8x8x32***********************/
  635. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_QUINT8_DIRECT_STRIDE1) {
  636. checker_conv_bias_quint8x8x32(
  637. get_conv_bias_args({2, 3, 5, 7}, 1, false, true, true), handle(),
  638. "ARMDOTU8STRD1");
  639. }
  640. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_QUINT8_DIRECT_STRIDE2) {
  641. checker_conv_bias_quint8x8x32(
  642. get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(),
  643. "ARMDOTU8STRD2");
  644. }
  645. /******************************dot int8x8x8 nchw44 ***********************/
  646. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_DIRECT_DOT_NCHW44_S1_Q8x8x8) {
  647. using namespace conv_bias;
  648. std::vector<TestArg> args =
  649. get_nchw44_conv_bias_args({2, 3, 5, 7}, QUAN_NLMODE, ONLY_BR_BIASMODE, 1);
  650. for (auto&& arg : args)
  651. arg.param.format = param::ConvBias::Format::NCHW44_DOT;
  652. checker_conv_bias_qint8x8x8(args, handle(), "ARMDOTS8DIRECT_NCHW44");
  653. }
  654. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_DIRECT_DOT_NCHW44_S1_Q8x8x32) {
  655. using namespace conv_bias;
  656. std::vector<TestArg> args = get_nchw44_conv_bias_args(
  657. {2, 3, 5, 7}, ONLY_IDENTITY_NLMODE, BR_AND_NO_BIASMODE, 1);
  658. for (auto&& arg : args)
  659. arg.param.format = param::ConvBias::Format::NCHW44_DOT;
  660. checker_conv_bias_qint8x8x32(args, handle(), "ARMDOTS8DIRECT_NCHW44");
  661. }
  662. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_DIRECT_DOT_NCHW44_S1_8x8x32) {
  663. using namespace conv_bias;
  664. std::vector<TestArg> args = get_nchw44_conv_bias_args(
  665. {2, 3, 5, 7}, ONLY_IDENTITY_NLMODE, BR_AND_NO_BIASMODE, 1);
  666. for (auto&& arg : args)
  667. arg.param.format = param::ConvBias::Format::NCHW44_DOT;
  668. checker_conv_bias_int8x8x32_multi(args, handle(), "ARMDOTS8DIRECT_NCHW44");
  669. }
  670. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_DIRECT_DOT_NCHW44_S2_Q8x8x8) {
  671. using namespace conv_bias;
  672. //! test qint8x8x8
  673. std::vector<TestArg> args =
  674. get_nchw44_conv_bias_args({2, 3, 5, 7}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 2);
  675. for (auto&& arg : args)
  676. arg.param.format = param::ConvBias::Format::NCHW44_DOT;
  677. checker_conv_bias_qint8x8x8(args, handle(), "ARMDOTS8DIRECT_NCHW44");
  678. }
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_DIRECT_DOT_NCHW44_S2_Q8x8x32) {
    using namespace conv_bias;
    //! test qint8x8x32 (fixed: comment previously said qint8x8x8, but this
    //! test calls checker_conv_bias_qint8x8x32)
    std::vector<TestArg> args = get_nchw44_conv_bias_args(
            {2, 3, 5, 7}, ONLY_IDENTITY_NLMODE, BR_AND_NO_BIASMODE, 2);
    for (auto&& arg : args)
        arg.param.format = param::ConvBias::Format::NCHW44_DOT;
    checker_conv_bias_qint8x8x32(args, handle(), "ARMDOTS8DIRECT_NCHW44");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_DIRECT_DOT_NCHW44_S2_8x8x32) {
    using namespace conv_bias;
    //! test int8x8x32 (fixed: comment previously said qint8x8x8, but this
    //! test calls checker_conv_bias_int8x8x32_multi)
    std::vector<TestArg> args = get_nchw44_conv_bias_args(
            {2, 3, 5, 7}, ONLY_IDENTITY_NLMODE, BR_AND_NO_BIASMODE, 2);
    for (auto&& arg : args)
        arg.param.format = param::ConvBias::Format::NCHW44_DOT;
    checker_conv_bias_int8x8x32_multi(args, handle(), "ARMDOTS8DIRECT_NCHW44");
}
  697. #endif
  698. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD) {
  699. using namespace conv_bias;
  700. std::vector<TestArg> args = get_winograd_args(3);
  701. Checker<ConvBiasForward> checker(handle());
  702. auto run = [&checker](
  703. const std::vector<TestArg>& args, DType A_dtype, DType B_dtype,
  704. DType C_dtype, DType D_dtype, const float eps) {
  705. for (auto&& arg : args) {
  706. checker.set_dtype(0, A_dtype)
  707. .set_dtype(1, B_dtype)
  708. .set_dtype(2, C_dtype)
  709. .set_dtype(4, D_dtype)
  710. .set_epsilon(eps)
  711. .set_param(arg.param)
  712. .execs({arg.src, arg.filter, arg.bias, {}, {}});
  713. }
  714. };
  715. run(args, dtype::Float32(), dtype::Float32(), dtype::Float32(), dtype::Float32(),
  716. 1e-3f);
  717. #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  718. Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00);
  719. checker.set_rng(0, rng).set_rng(1, rng).set_rng(2, rng);
  720. run(args, dtype::Float16(), dtype::Float16(), dtype::Float16(), dtype::Float16(),
  721. 0.35f);
  722. #endif
  723. }
  724. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F32_F43) {
  725. using namespace conv_bias;
  726. std::vector<TestArg> args = get_winograd_args(3);
  727. Checker<ConvBiasForward> checker(handle());
  728. check_winograd("1:4:32", checker, args);
  729. }
//! enable the `#if 0` block below once low-precision mode works correctly
  731. #if 0
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F73_4_NCHW44) {
    //! NOTE(review): this test sits inside an `#if 0` block — it is disabled
    //! until low-precision mode works (see the comment above the `#if 0`)
    using namespace conv_bias;
    std::vector<TestArg> args =
            get_nchw44_conv_bias_args({3},QUAN_NLMODE,BR_AND_NO_BIASMODE,1);
    Checker<ConvBiasForward> checker(handle());
    check_winograd("4:7:16", checker, args, param::MatrixMul::Format::MK4,
            param::ConvBias::Format::NCHW44);
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F73_4_NCHW44_WEIGHT_PREPROCESS) {
    //! NOTE(review): disabled via the surrounding `#if 0`; same cases as
    //! CONV_BIAS_WINOGRAD_F73_4_NCHW44 but run through the weight-preprocess proxy
    using namespace conv_bias;
    std::vector<TestArg> args =
            get_nchw44_conv_bias_args({3},QUAN_NLMODE,BR_AND_NO_BIASMODE,1);
    Checker<ConvBiasForward, OprWeightPreprocessProxy<ConvBiasForward>> checker(
            handle());
    check_winograd("4:7:16", checker, args, param::MatrixMul::Format::MK4,
            param::ConvBias::Format::NCHW44);
}
  749. #endif
  750. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_MK_PACKED_F32_1) {
  751. using namespace conv_bias;
  752. Checker<ConvBiasForward> checker(handle());
  753. auto run = [&checker](
  754. const std::vector<TestArg>& args, DType A_dtype, DType B_dtype,
  755. DType C_dtype, DType D_dtype, float eps) {
  756. for (auto&& arg : args) {
  757. checker.set_dtype(0, A_dtype)
  758. .set_dtype(1, B_dtype)
  759. .set_dtype(2, C_dtype)
  760. .set_dtype(4, D_dtype)
  761. .set_epsilon(eps)
  762. .set_param(arg.param)
  763. .execs({arg.src, arg.filter, arg.bias, {}, {}});
  764. }
  765. };
  766. std::vector<TestArg> args = get_winograd_mk_packed_args(8);
  767. std::vector<TestArg> args_first_half(args.begin(), args.begin() + args.size() / 2);
  768. run(args_first_half, dtype::Float32{}, dtype::Float32{}, dtype::Float32{},
  769. dtype::Float32{}, 1e-3f);
  770. }
  771. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_MK_PACKED_F32_2) {
  772. using namespace conv_bias;
  773. Checker<ConvBiasForward> checker(handle());
  774. auto run = [&checker](
  775. const std::vector<TestArg>& args, DType A_dtype, DType B_dtype,
  776. DType C_dtype, DType D_dtype, float eps) {
  777. for (auto&& arg : args) {
  778. checker.set_dtype(0, A_dtype)
  779. .set_dtype(1, B_dtype)
  780. .set_dtype(2, C_dtype)
  781. .set_dtype(4, D_dtype)
  782. .set_epsilon(eps)
  783. .set_param(arg.param)
  784. .execs({arg.src, arg.filter, arg.bias, {}, {}});
  785. }
  786. };
  787. std::vector<TestArg> args = get_winograd_mk_packed_args(8);
  788. std::vector<TestArg> args_second_half(args.begin() + args.size() / 2, args.end());
  789. run(args_second_half, dtype::Float32{}, dtype::Float32{}, dtype::Float32{},
  790. dtype::Float32{}, 1e-3f);
  791. }
  792. #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  793. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_MK_PACKED_F16) {
  794. using namespace conv_bias;
  795. Checker<ConvBiasForward> checker(handle());
  796. auto run = [&checker](
  797. const std::vector<TestArg>& args, DType A_dtype, DType B_dtype,
  798. DType C_dtype, DType D_dtype, float eps) {
  799. for (auto&& arg : args) {
  800. checker.set_dtype(0, A_dtype)
  801. .set_dtype(1, B_dtype)
  802. .set_dtype(2, C_dtype)
  803. .set_dtype(4, D_dtype)
  804. .set_epsilon(eps)
  805. .set_param(arg.param)
  806. .execs({arg.src, arg.filter, arg.bias, {}, {}});
  807. }
  808. };
  809. std::vector<TestArg> args = get_winograd_mk_packed_args(8);
  810. Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00);
  811. checker.set_rng(0, rng).set_rng(1, rng).set_rng(2, rng);
  812. run(args, dtype::Float16{}, dtype::Float16{}, dtype::Float16{}, dtype::Float16{},
  813. 0.25);
  814. }
  815. #endif
  816. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_MK_PACKED_INT8) {
  817. using namespace conv_bias;
  818. Checker<ConvBiasForward> checker(handle());
  819. auto run = [&checker](
  820. const std::vector<TestArg>& args, DType A_dtype, DType B_dtype,
  821. DType C_dtype, DType D_dtype, float eps) {
  822. for (auto&& arg : args) {
  823. checker.set_dtype(0, A_dtype)
  824. .set_dtype(1, B_dtype)
  825. .set_dtype(2, C_dtype)
  826. .set_dtype(4, D_dtype)
  827. .set_epsilon(eps)
  828. .set_param(arg.param)
  829. .execs({arg.src, arg.filter, arg.bias, {}, {}});
  830. }
  831. };
  832. #if MEGDNN_AARCH64
  833. const char* matmul_name = "AARCH64_INT16X16X32_MK8_8X8";
  834. #else
  835. const char* matmul_name = "ARMV7_INT16X16X32_MK8_4X8";
  836. #endif
  837. checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>(
  838. ssprintf("WINOGRAD:%s:8:2:32", matmul_name).c_str()));
  839. std::vector<TestArg> quantized_args = get_quantized_winograd_mk_packed_args(8);
  840. UniformIntRNG int_rng{-50, 50};
  841. checker.set_rng(0, &int_rng).set_rng(1, &int_rng).set_rng(2, &int_rng);
  842. run(quantized_args, dtype::QuantizedS8(2.5f), dtype::QuantizedS8(2.5f),
  843. dtype::QuantizedS32(6.25f), dtype::QuantizedS8(60.25f), 1e-3);
  844. }
  845. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8) {
  846. using namespace conv_bias;
  847. Checker<ConvBiasForward> checker(handle());
  848. auto run = [&checker](
  849. const std::vector<TestArg>& args, DType A_dtype, DType B_dtype,
  850. DType C_dtype, DType D_dtype, float eps) {
  851. for (auto&& arg : args) {
  852. checker.set_dtype(0, A_dtype)
  853. .set_dtype(1, B_dtype)
  854. .set_dtype(2, C_dtype)
  855. .set_dtype(4, D_dtype)
  856. .set_epsilon(eps)
  857. .set_param(arg.param)
  858. .execs({arg.src, arg.filter, arg.bias, {}, {}});
  859. }
  860. };
  861. #if MEGDNN_AARCH64
  862. const char* matmul_name = "AARCH64_INT16X16X32_MK8_8X8";
  863. #else
  864. const char* matmul_name = "ARMV7_INT16X16X32_MK8_4X8";
  865. #endif
  866. checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>(
  867. ssprintf("WINOGRAD_NCHW44:%s:8:2:32", matmul_name).c_str()));
  868. std::vector<TestArg> quantized_args = get_int8_nchw44_args(3, 4);
  869. UniformIntRNG int_rng{-50, 50};
  870. checker.set_rng(0, &int_rng).set_rng(1, &int_rng).set_rng(2, &int_rng);
  871. run(quantized_args, dtype::QuantizedS8(2.5f), dtype::QuantizedS8(2.5f),
  872. dtype::QuantizedS32(6.25f), dtype::QuantizedS8(60.25f), 1e-3);
  873. }
  874. TEST_F(ARM_COMMON_MULTI_THREADS,
  875. CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8_COMP_F32_GROUPMODE) {
  876. using namespace conv_bias;
  877. Checker<ConvBiasForward> checker(handle());
  878. auto run = [&checker](
  879. const std::vector<TestArg>& args, DType A_dtype, DType B_dtype,
  880. DType C_dtype, DType D_dtype, float eps) {
  881. for (auto&& arg : args) {
  882. checker.set_dtype(0, A_dtype)
  883. .set_dtype(1, B_dtype)
  884. .set_dtype(2, C_dtype)
  885. .set_dtype(4, D_dtype)
  886. .set_epsilon(eps)
  887. .set_param(arg.param)
  888. .execs({arg.src, arg.filter, arg.bias, {}, {}});
  889. }
  890. };
  891. float epsilon = 0.001;
  892. #if MEGDNN_AARCH64
  893. const char* matmul_name = "AARCH64_F32_MK4_4x16";
  894. #else
  895. const char* matmul_name = "ARMV7_F32_MK4_4x8";
  896. #endif
  897. checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>(
  898. ssprintf("WINOGRAD_NCHW44:%s:4:2:32", matmul_name).c_str()));
  899. std::vector<TestArg> quantized_args = get_int8_nchw44_args(3, 4, true, true);
  900. UniformIntRNG int_rng{-50, 50};
  901. checker.set_rng(0, &int_rng).set_rng(1, &int_rng).set_rng(2, &int_rng);
  902. run(quantized_args, dtype::QuantizedS8(0.41113496f),
  903. dtype::QuantizedS8(0.01887994f), dtype::QuantizedS32(0.41113496f * 0.01887994f),
  904. dtype::QuantizedS8(0.49550694f), epsilon);
  905. }
  906. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8_GROUPMODE) {
  907. using namespace conv_bias;
  908. Checker<ConvBiasForward> checker(handle());
  909. auto run = [&checker](
  910. const std::vector<TestArg>& args, DType A_dtype, DType B_dtype,
  911. DType C_dtype, DType D_dtype, float eps) {
  912. for (auto&& arg : args) {
  913. checker.set_dtype(0, A_dtype)
  914. .set_dtype(1, B_dtype)
  915. .set_dtype(2, C_dtype)
  916. .set_dtype(4, D_dtype)
  917. .set_epsilon(eps)
  918. .set_param(arg.param)
  919. .execs({arg.src, arg.filter, arg.bias, {}, {}});
  920. }
  921. };
  922. #if MEGDNN_AARCH64
  923. const char* matmul_name = "AARCH64_INT16X16X32_MK8_8X8";
  924. #else
  925. const char* matmul_name = "ARMV7_INT16X16X32_MK8_4X8";
  926. #endif
  927. checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>(
  928. ssprintf("WINOGRAD_NCHW44:%s:8:2:32", matmul_name).c_str()));
  929. std::vector<TestArg> quantized_args = get_int8_nchw44_args(3, 4, false, true);
  930. UniformIntRNG int_rng{-50, 50};
  931. checker.set_rng(0, &int_rng).set_rng(1, &int_rng).set_rng(2, &int_rng);
  932. run(quantized_args, dtype::QuantizedS8(2.5f), dtype::QuantizedS8(2.5f),
  933. dtype::QuantizedS32(6.25f), dtype::QuantizedS8(60.25f), 1e-3);
  934. }
  935. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8_COMP_F32) {
  936. using namespace conv_bias;
  937. Checker<ConvBiasForward> checker(handle());
  938. auto run = [&checker](
  939. const std::vector<TestArg>& args, DType A_dtype, DType B_dtype,
  940. DType C_dtype, DType D_dtype, float eps) {
  941. for (auto&& arg : args) {
  942. checker.set_dtype(0, A_dtype)
  943. .set_dtype(1, B_dtype)
  944. .set_dtype(2, C_dtype)
  945. .set_dtype(4, D_dtype)
  946. .set_epsilon(eps)
  947. .set_param(arg.param)
  948. .execs({arg.src, arg.filter, arg.bias, {}, {}});
  949. }
  950. };
  951. float epsilon = 0.001;
  952. #if MEGDNN_AARCH64
  953. const char* matmul_name = "AARCH64_F32_MK4_4x16";
  954. #else
  955. const char* matmul_name = "ARMV7_F32_MK4_4x8";
  956. #endif
  957. checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>(
  958. ssprintf("WINOGRAD_NCHW44:%s:4:2:32", matmul_name).c_str()));
  959. std::vector<TestArg> quantized_args = get_int8_nchw44_args(3, 4, true);
  960. UniformIntRNG int_rng{-50, 50};
  961. checker.set_rng(0, &int_rng).set_rng(1, &int_rng).set_rng(2, &int_rng);
  962. run(quantized_args, dtype::QuantizedS8(0.41113496f),
  963. dtype::QuantizedS8(0.01887994f), dtype::QuantizedS32(0.41113496f * 0.01887994f),
  964. dtype::QuantizedS8(0.49550694f), epsilon);
  965. }
  966. #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  967. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_F23) {
  968. using namespace conv_bias;
  969. std::vector<TestArg> args = get_winograd_mk_packed_args();
  970. Checker<ConvBiasForward> checker(handle());
  971. check_winograd_fp16("1:2:32", checker, args, NULL, 0.08);
  972. }
  973. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_F45_1) {
  974. using namespace conv_bias;
  975. std::vector<TestArg> args = get_winograd_args(5);
  976. std::vector<TestArg> args_head_half(args.begin(), args.begin() + args.size() / 2);
  977. Checker<ConvBiasForward> checker(handle());
  978. //! fp16 range -1.0 ~ 1.0
  979. Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00);
  980. check_winograd_fp16("1:4:32", checker, args_head_half, rng, 0.25);
  981. }
  982. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_F45_2) {
  983. using namespace conv_bias;
  984. std::vector<TestArg> args = get_winograd_args(5);
  985. std::vector<TestArg> args_back_half(args.begin() + args.size() / 2, args.end());
  986. Checker<ConvBiasForward> checker(handle());
  987. //! fp16 range -1.0 ~ 1.0
  988. Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00);
  989. check_winograd_fp16("1:4:32", checker, args_back_half, rng, 0.25);
  990. }
//! FIXME: this test may fail when run together with `ARM_COMMON.CONV_BIAS_WINOGRAD*`,
//! but it passes when run as a single testcase
  993. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_F63) {
  994. using namespace conv_bias;
  995. std::vector<TestArg> args = get_winograd_args(3);
  996. Checker<ConvBiasForward> checker(handle());
  997. //! fp16 range -1.0 ~ 1.0
  998. Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00);
  999. check_winograd_fp16("1:6:32", checker, args, rng, 0.3);
  1000. }
  1001. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_8x8_1) {
  1002. using namespace conv_bias;
  1003. std::vector<TestArg> args = get_winograd_mk_packed_args(8);
  1004. std::vector<TestArg> args_head_half(args.begin(), args.begin() + args.size() / 2);
  1005. Checker<ConvBiasForward> checker(handle());
  1006. Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00);
  1007. check_winograd_fp16(
  1008. "8:2:32", checker, args_head_half, rng, 0.25,
  1009. param::MatrixMul::Format::MK8);
  1010. }
  1011. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_8x8_2) {
  1012. using namespace conv_bias;
  1013. std::vector<TestArg> args = get_winograd_mk_packed_args(8);
  1014. std::vector<TestArg> args_back_half(args.begin() + args.size() / 2, args.end());
  1015. Checker<ConvBiasForward> checker(handle());
  1016. Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00);
  1017. check_winograd_fp16(
  1018. "8:2:32", checker, args_back_half, rng, 0.25,
  1019. param::MatrixMul::Format::MK8);
  1020. }
  1021. #endif
  1022. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_INT8_8X8) {
  1023. using namespace conv_bias;
  1024. std::vector<TestArg> args = get_quantized_winograd_mk_packed_args(8);
  1025. Checker<ConvBiasForward> checker(handle());
  1026. UniformIntRNG rng{-50, 50};
  1027. checker.set_dtype(0, dtype::QuantizedS8(2.5f))
  1028. .set_dtype(1, dtype::QuantizedS8(2.5f))
  1029. .set_dtype(2, dtype::QuantizedS32(6.25f))
  1030. .set_dtype(4, dtype::QuantizedS8(60.25f))
  1031. .set_rng(0, &rng)
  1032. .set_rng(1, &rng)
  1033. .set_rng(2, &rng);
  1034. check_winograd("8:2:32", checker, args, param::MatrixMul::Format::MK8);
  1035. }
  1036. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_INT8_8X8_WEIGHT_PREPROCESS) {
  1037. using namespace conv_bias;
  1038. std::vector<TestArg> args = get_quantized_winograd_mk_packed_args(8);
  1039. Checker<ConvBiasForward, OprWeightPreprocessProxy<ConvBiasForward>> checker(
  1040. handle());
  1041. UniformIntRNG rng{-50, 50};
  1042. checker.set_dtype(0, dtype::QuantizedS8(2.5f))
  1043. .set_dtype(1, dtype::QuantizedS8(2.5f))
  1044. .set_dtype(2, dtype::QuantizedS32(6.25f))
  1045. .set_dtype(4, dtype::QuantizedS8(60.25f))
  1046. .set_rng(0, &rng)
  1047. .set_rng(1, &rng)
  1048. .set_rng(2, &rng);
  1049. check_winograd("8:2:32", checker, args, param::MatrixMul::Format::MK8);
  1050. }
  1051. // clang-format on
  1052. // vim: syntax=cpp.doxygen