You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

conv_bias_multi_thread.cpp 46 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157
  1. #include "megdnn/dtype.h"
  2. #include "test/arm_common/fixture.h"
  3. #include "test/common/benchmarker.h"
  4. #include "test/common/conv_bias.h"
  5. #include "test/arm_common/cpuinfo_help.h"
  6. using namespace megdnn;
  7. using namespace test;
  8. using namespace conv_bias;
  9. std::vector<conv_bias::TestArg> get_int8_quint8_conv_bias_args(
  10. std::vector<size_t> kernel, size_t stride, bool no_pad, bool no_bias,
  11. bool no_nonlinemode) {
  12. using namespace conv_bias;
  13. using Param = param::ConvBias;
  14. using NLMode = param::ConvBias::NonlineMode;
  15. std::vector<TestArg> args;
  16. auto pack = [&](size_t n, size_t oc, size_t ic, size_t w, size_t h, size_t kernel,
  17. size_t stride, NLMode nlmode) {
  18. Param param;
  19. param.stride_h = stride;
  20. param.stride_w = stride;
  21. if (!no_pad) {
  22. param.pad_h = kernel / 2;
  23. param.pad_w = kernel / 2;
  24. } else {
  25. param.pad_h = 0;
  26. param.pad_w = 0;
  27. }
  28. param.nonlineMode = nlmode;
  29. args.emplace_back(
  30. param, TensorShape{n, ic, h, w}, TensorShape{oc, ic, kernel, kernel},
  31. TensorShape{});
  32. if (!no_bias) {
  33. args.emplace_back(
  34. param, TensorShape{n, ic, h, w},
  35. TensorShape{oc, ic, kernel, kernel}, TensorShape{1, oc, 1, 1});
  36. }
  37. };
  38. std::vector<NLMode> nonlinemode = {NLMode::IDENTITY};
  39. if (!no_nonlinemode) {
  40. nonlinemode.emplace_back(NLMode::RELU);
  41. nonlinemode.emplace_back(NLMode::H_SWISH);
  42. }
  43. for (size_t n : {1, 2}) {
  44. for (auto nlmode : nonlinemode) {
  45. for (size_t ic : {1, 3, 7}) {
  46. for (size_t oc : {1, 3, 7}) {
  47. for (size_t size : {4, 6, 8, 14, 16, 18}) {
  48. for (size_t kern : kernel) {
  49. pack(n, oc, ic, size, size, kern, stride, nlmode);
  50. }
  51. }
  52. }
  53. }
  54. }
  55. }
  56. return args;
  57. }
  58. std::vector<conv_bias::TestArg> get_nchw44_channel_wise_args(
  59. std::vector<size_t> kernel, size_t stride, bool no_bias, bool no_nonlinemode,
  60. bool no_full_bias) {
  61. using namespace conv_bias;
  62. using Param = param::ConvBias;
  63. using NLMode = param::ConvBias::NonlineMode;
  64. std::vector<TestArg> args;
  65. auto pack = [&](size_t n, size_t group, size_t w, size_t h, size_t kernel,
  66. size_t stride, NLMode nlmode, bool pad) {
  67. Param param;
  68. param.stride_h = stride;
  69. param.stride_w = stride;
  70. if (pad) {
  71. param.pad_h = kernel / 2;
  72. param.pad_w = kernel / 2;
  73. } else {
  74. param.pad_h = 0;
  75. param.pad_w = 0;
  76. }
  77. param.nonlineMode = nlmode;
  78. param.format = param::ConvBias::Format::NCHW44;
  79. param.sparse = param::ConvBias::Sparse::GROUP;
  80. args.emplace_back(
  81. param, TensorShape{n, group, h, w, 4},
  82. TensorShape{group, 1, 1, kernel, kernel, 4}, TensorShape{});
  83. if (!no_bias) {
  84. args.emplace_back(
  85. param, TensorShape{n, group, h, w, 4},
  86. TensorShape{group, 1, 1, kernel, kernel, 4},
  87. TensorShape{1, group, 1, 1, 4});
  88. }
  89. if (!no_full_bias) {
  90. args.emplace_back(
  91. param, TensorShape{n, group, h, w, 4},
  92. TensorShape{group, 1, 1, kernel, kernel, 4},
  93. TensorShape{
  94. n, group, (h + 2 * param.pad_w - kernel) / stride + 1,
  95. (w + 2 * param.pad_w - kernel) / stride + 1, 4});
  96. }
  97. };
  98. std::vector<NLMode> nonlinemode = {NLMode::IDENTITY};
  99. if (!no_nonlinemode) {
  100. nonlinemode.emplace_back(NLMode::RELU);
  101. nonlinemode.emplace_back(NLMode::H_SWISH);
  102. }
  103. for (size_t n : {1, 2}) {
  104. for (auto nlmode : nonlinemode) {
  105. for (bool pad : {true}) {
  106. for (size_t group : {1, 2, 4, 7, 16}) {
  107. for (size_t size : {4, 6, 7, 9, 20}) {
  108. for (size_t kern : kernel) {
  109. pack(n, group, size, size, kern, stride, nlmode, pad);
  110. }
  111. }
  112. }
  113. }
  114. for (bool pad : {false}) {
  115. for (size_t group : {1, 2, 7, 16}) {
  116. for (size_t size : {7, 9, 20}) {
  117. for (size_t kern : kernel) {
  118. pack(n, group, size, size, kern, stride, nlmode, pad);
  119. }
  120. }
  121. }
  122. }
  123. }
  124. }
  125. return args;
  126. }
  127. std::vector<conv_bias::TestArg> get_channel_wise_args(
  128. std::vector<size_t> kernel, size_t stride, bool no_bias, bool no_nonlinemode,
  129. bool no_full_bias, bool support_relu) {
  130. using namespace conv_bias;
  131. using Param = param::ConvBias;
  132. using NLMode = param::ConvBias::NonlineMode;
  133. std::vector<TestArg> args;
  134. auto pack = [&](size_t n, size_t group, size_t w, size_t h, size_t kernel,
  135. size_t stride, NLMode nlmode, bool pad) {
  136. Param param;
  137. param.stride_h = stride;
  138. param.stride_w = stride;
  139. if (pad) {
  140. param.pad_h = kernel / 2;
  141. param.pad_w = kernel / 2;
  142. } else {
  143. param.pad_h = 0;
  144. param.pad_w = 0;
  145. }
  146. param.nonlineMode = nlmode;
  147. param.format = param::ConvBias::Format::NCHW;
  148. param.sparse = param::ConvBias::Sparse::GROUP;
  149. args.emplace_back(
  150. param, TensorShape{n, group, h, w},
  151. TensorShape{group, 1, 1, kernel, kernel}, TensorShape{});
  152. if (!no_bias) {
  153. args.emplace_back(
  154. param, TensorShape{n, group, h, w},
  155. TensorShape{group, 1, 1, kernel, kernel},
  156. TensorShape{1, group, 1, 1});
  157. }
  158. if (!no_full_bias) {
  159. args.emplace_back(
  160. param, TensorShape{n, group, h, w},
  161. TensorShape{group, 1, 1, kernel, kernel},
  162. TensorShape{
  163. n, group, (h + 2 * param.pad_w - kernel) / stride + 1,
  164. (w + 2 * param.pad_w - kernel) / stride + 1});
  165. }
  166. };
  167. std::vector<NLMode> nonlinemode = {NLMode::IDENTITY};
  168. if (!no_nonlinemode) {
  169. nonlinemode.emplace_back(NLMode::RELU);
  170. nonlinemode.emplace_back(NLMode::H_SWISH);
  171. } else if (support_relu) {
  172. nonlinemode.emplace_back(NLMode::RELU);
  173. }
  174. for (size_t n : {1, 2}) {
  175. for (auto nlmode : nonlinemode) {
  176. for (bool pad : {true}) {
  177. for (size_t group : {1, 3, 7}) {
  178. for (size_t size : {4, 6, 7, 9, 16, 20, 32, 55}) {
  179. for (size_t kern : kernel) {
  180. pack(n, group, size, size, kern, stride, nlmode, pad);
  181. }
  182. }
  183. }
  184. }
  185. for (bool pad : {false}) {
  186. for (size_t group : {7}) {
  187. for (size_t size : {37}) {
  188. for (size_t kern : kernel) {
  189. if (size < kern)
  190. continue;
  191. pack(n, group, size, size, kern, stride, nlmode, pad);
  192. }
  193. }
  194. }
  195. }
  196. }
  197. }
  198. return args;
  199. }
  200. std::vector<conv_bias::TestArg> get_nchw88_channel_wise_args(
  201. std::vector<size_t> kernel, size_t stride, bool no_bias, bool no_nonlinemode,
  202. bool no_full_bias) {
  203. using namespace conv_bias;
  204. using Param = param::ConvBias;
  205. using NLMode = param::ConvBias::NonlineMode;
  206. std::vector<TestArg> args;
  207. auto pack = [&](size_t n, size_t group, size_t w, size_t h, size_t kernel,
  208. size_t stride, NLMode nlmode, bool pad) {
  209. Param param;
  210. param.stride_h = stride;
  211. param.stride_w = stride;
  212. if (pad) {
  213. param.pad_h = kernel / 2;
  214. param.pad_w = kernel / 2;
  215. } else {
  216. param.pad_h = 0;
  217. param.pad_w = 0;
  218. }
  219. param.nonlineMode = nlmode;
  220. param.format = param::ConvBias::Format::NCHW88;
  221. param.sparse = param::ConvBias::Sparse::GROUP;
  222. args.emplace_back(
  223. param, TensorShape{n, group, h, w, 8},
  224. TensorShape{group, 1, 1, kernel, kernel, 8}, TensorShape{});
  225. if (!no_bias) {
  226. args.emplace_back(
  227. param, TensorShape{n, group, h, w, 8},
  228. TensorShape{group, 1, 1, kernel, kernel, 8},
  229. TensorShape{1, group, 1, 1, 8});
  230. }
  231. if (!no_full_bias) {
  232. args.emplace_back(
  233. param, TensorShape{n, group, h, w, 8},
  234. TensorShape{group, 1, 1, kernel, kernel, 8},
  235. TensorShape{
  236. n, group, (h + 2 * param.pad_w - kernel) / stride + 1,
  237. (w + 2 * param.pad_w - kernel) / stride + 1, 8});
  238. }
  239. };
  240. std::vector<NLMode> nonlinemode = {NLMode::IDENTITY};
  241. if (!no_nonlinemode) {
  242. nonlinemode.emplace_back(NLMode::RELU);
  243. nonlinemode.emplace_back(NLMode::H_SWISH);
  244. }
  245. for (size_t n : {1, 2}) {
  246. for (auto nlmode : nonlinemode) {
  247. for (bool pad : {true}) {
  248. for (size_t group : {1, 2, 4, 7, 8, 16}) {
  249. for (size_t size : {4, 6, 7, 9, 20}) {
  250. for (size_t kern : kernel) {
  251. pack(n, group, size, size, kern, stride, nlmode, pad);
  252. }
  253. }
  254. }
  255. }
  256. for (bool pad : {false}) {
  257. for (size_t group : {1, 2, 7, 16}) {
  258. for (size_t size : {7, 9, 20}) {
  259. for (size_t kern : kernel) {
  260. pack(n, group, size, size, kern, stride, nlmode, pad);
  261. }
  262. }
  263. }
  264. }
  265. }
  266. }
  267. return args;
  268. }
  269. void checker_conv_bias_qint8x8x8(
  270. std::vector<conv_bias::TestArg> args, Handle* handle, const char* algo_name) {
  271. Checker<ConvBias> checker(handle);
  272. checker.set_before_exec_callback(
  273. conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name));
  274. #if MEGDNN_ARMV7
  275. checker.set_epsilon(1);
  276. #endif
  277. UniformIntRNG rng{-50, 50};
  278. checker.set_dtype(0, dtype::QuantizedS8(0.41113496f))
  279. .set_dtype(1, dtype::QuantizedS8(0.01887994f))
  280. .set_dtype(2, dtype::QuantizedS32(0.41113496f * 0.01887994f))
  281. .set_dtype(4, dtype::QuantizedS8(0.49550694f))
  282. .set_rng(0, &rng)
  283. .set_rng(1, &rng)
  284. .set_rng(2, &rng);
  285. for (auto&& arg : args) {
  286. checker.set_param(arg.param).execs({arg.src, arg.filter, arg.bias, {}, {}});
  287. }
  288. }
  289. void checker_conv_bias_qint8x8x32(
  290. std::vector<conv_bias::TestArg> args, Handle* handle, const char* algo_name) {
  291. Checker<ConvBias> checker(handle);
  292. UniformIntRNG rng{-50, 50};
  293. checker.set_dtype(0, dtype::QuantizedS8(2.5f))
  294. .set_dtype(1, dtype::QuantizedS8(2.5f))
  295. .set_dtype(2, dtype::QuantizedS32(6.25f))
  296. .set_dtype(4, {});
  297. checker.set_before_exec_callback(
  298. conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name));
  299. for (auto&& arg : args) {
  300. checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}});
  301. }
  302. }
  303. void checker_conv_bias_quint8x8x8(
  304. std::vector<conv_bias::TestArg> args, Handle* handle, const char* algo_name) {
  305. Checker<ConvBias> checker(handle);
  306. checker.set_before_exec_callback(
  307. conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name));
  308. UniformIntRNG rng(0, 255);
  309. checker.set_dtype(0, dtype::Quantized8Asymm(0.2f, 100))
  310. .set_dtype(1, dtype::Quantized8Asymm(0.2f, 120))
  311. .set_dtype(2, dtype::QuantizedS32(0.04f))
  312. .set_dtype(4, dtype::Quantized8Asymm(1.4f, 110))
  313. .set_rng(0, &rng)
  314. .set_rng(1, &rng)
  315. .set_rng(2, &rng);
  316. for (auto&& arg : args) {
  317. checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}});
  318. }
  319. }
  320. void checker_conv_bias_quint8x8x32(
  321. std::vector<conv_bias::TestArg> args, Handle* handle, const char* algo_name) {
  322. Checker<ConvBias> checker(handle);
  323. checker.set_before_exec_callback(
  324. conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name));
  325. NormalRNG rng(128.f);
  326. checker.set_rng(0, &rng).set_rng(1, &rng);
  327. checker.set_dtype(0, dtype::Quantized8Asymm(1.2f, (uint8_t)127))
  328. .set_dtype(1, dtype::Quantized8Asymm(1.3f, (uint8_t)129))
  329. .set_dtype(2, dtype::QuantizedS32(1.2 * 1.3))
  330. .set_dtype(4, {});
  331. for (auto&& arg : args) {
  332. checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}});
  333. }
  334. }
  335. void checker_conv_bias_int8x8x32_multi(
  336. std::vector<conv_bias::TestArg> args, Handle* handle, const char* algo_name) {
  337. Checker<ConvBias> checker(handle);
  338. checker.set_before_exec_callback(
  339. conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name));
  340. checker.set_dtype(0, dtype::Int8());
  341. checker.set_dtype(1, dtype::Int8());
  342. checker.set_dtype(2, dtype::Int32());
  343. checker.set_dtype(4, dtype::Int32());
  344. for (auto&& arg : args) {
  345. checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}});
  346. }
  347. }
/**********************************F16 direct************************/
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
//! fp16 direct conv, kernel sizes 1-7, stride 1
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16) {
    NormalRNG rng(1);
    checker_conv_bias_f16(
            get_conv_bias_args({1, 2, 3, 4, 5, 6, 7}, 1, false, false, false), handle(),
            rng, "F16DIRECT", 0.03);
}
//! fp16 stride-1 specialized kernels (2x2/3x3/5x5)
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16_STR1) {
    NormalRNG rng(1);
    checker_conv_bias_f16(
            get_conv_bias_args({2, 3, 5}, 1, false, false, false), handle(), rng,
            "F16STRD1", 0.03);
}
//! fp16 channel-wise NCHW88, stride 1, kernels 2/3
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CHANNEL_WISE_STRIDE1_FP16_NCHW88_1) {
    NormalRNG rng(1);
    checker_conv_bias_f16(
            get_nchw88_channel_wise_args({2, 3}, 1, false, false, false), handle(), rng,
            "F16_CHANNEL_WISE_NCHW88", 0.03);
}
//! fp16 channel-wise NCHW88, stride 1, kernel 5
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CHANNEL_WISE_STRIDE1_FP16_NCHW88_2) {
    NormalRNG rng(1);
    checker_conv_bias_f16(
            get_nchw88_channel_wise_args({5}, 1, false, false, false), handle(), rng,
            "F16_CHANNEL_WISE_NCHW88", 0.03);
}
//! fp16 channel-wise NCHW88, stride 2
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CHANNEL_WISE_STRIDE2_FP16_NCHW88) {
    NormalRNG rng(1);
    checker_conv_bias_f16(
            get_nchw88_channel_wise_args({2, 3, 5}, 2, false, false, false), handle(),
            rng, "F16_CHANNEL_WISE_NCHW88", 0.03);
}
//! fp16 dense NCHW88 direct conv, stride 1
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16_NCHW88_S1) {
    NormalRNG rng(1);
    checker_conv_bias_f16(
            get_nchw88_conv_bias_args({1, 2, 3, 5, 7}, FULL_NLMODE, ALL_BIASMODE, 1),
            handle(), rng, "F16_CONV_NCHW88_DIRECT", 0.03);
}
//! fp16 dense NCHW88 direct conv, stride 2
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16_NCHW88_S2) {
    NormalRNG rng(1);
    checker_conv_bias_f16(
            get_nchw88_conv_bias_args({1, 2, 3, 5, 7}, FULL_NLMODE, ALL_BIASMODE, 2),
            handle(), rng, "F16_CONV_NCHW88_DIRECT", 0.03);
}
#endif
/**********************************algo 8816 direct************************/
//! int8 x int8 -> int16 direct conv, stride 1
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT16_DIRECT) {
    checker_conv_bias_int8x8x16(
            get_conv_bias_args({2, 3, 5}, 1, false, true, true), handle(),
            "I8816DIRECT");
}
//! int8 x int8 -> int16 direct conv, stride 2
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT16_STRIDE2) {
    checker_conv_bias_int8x8x16(
            get_conv_bias_args({2, 3, 5}, 2, false, true, true), handle(),
            "I8816STRD2");
}
//! int8x8x16 hybrid NCHW input -> NCHW44 output conv, stride 2
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT16_NCHW_NCHW44_S2) {
    checker_conv_bias_int8x8x16(
            get_nchw44_conv_bias_args(
                    {2, 3, 5, 7}, ONLY_IDENTITY_NLMODE, ONLY_NO_BIASMODE, 2, false,
                    true),
            handle(), "I8816_CONV_NCHW_NCHW44");
}
//! int8x8x16 hybrid NCHW input -> NCHW44 output conv, stride 1
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT16_NCHW_NCHW44_S1) {
    checker_conv_bias_int8x8x16(
            get_nchw44_conv_bias_args(
                    {2, 3, 5, 7}, ONLY_IDENTITY_NLMODE, ONLY_NO_BIASMODE, 1, false,
                    true),
            handle(), "I8816_CONV_NCHW_NCHW44");
}
/**********************************algo 8-8-32 direct************************/
//! int8 x int8 -> int32 direct conv, stride 1
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT32_STRIDE1) {
    checker_conv_bias_int8x8x32_multi(
            get_conv_bias_args({2, 3, 5, 7}, 1, false, true, true), handle(),
            "S8STRD1");
}
//! int8 x int8 -> int32 direct conv, stride 2
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT32_STRIDE2) {
    checker_conv_bias_int8x8x32_multi(
            get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(),
            "S8STRD2");
}
//! int8x8x32 channel-wise NCHW44, stride 1
TEST_F(ARM_COMMON_MULTI_THREADS,
       CONV_BIAS_INT8_INT8_INT32_CHANNEL_WISE_DIRECT1_NCHW44) {
    checker_conv_bias_int8x8x32_multi(
            get_nchw44_channel_wise_args({2, 3, 5}, 1, false, true, true), handle(),
            "S8_CHAN_WISE_STRD1_NCHW44");
}
//! int8x8x32 channel-wise NCHW44, stride 2
TEST_F(ARM_COMMON_MULTI_THREADS,
       CONV_BIAS_INT8_INT8_INT32_CHANNEL_WISE_DIRECT2_NCHW44) {
    checker_conv_bias_int8x8x32_multi(
            get_nchw44_channel_wise_args({2, 3, 5}, 2, false, true, true), handle(),
            "S8_CHAN_WISE_STRD2_NCHW44");
}
//! int8x8x16 channel-wise NCHW44, stride 1 (single-thread fixture)
TEST_F(ARM_COMMON, CONV_BIAS_INT8_INT8_INT16_CHANNEL_WISE_DIRECT1_NCHW44) {
    Checker<ConvBias> checker(handle());
    checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>(
            "S8x8x16_CHAN_WISE_STRD1_STRD2_NCHW44"));
    checker.set_dtype(0, dtype::Int8());
    checker.set_dtype(1, dtype::Int8());
    checker.set_dtype(2, dtype::Int16());
    checker.set_dtype(4, dtype::Int16());
    auto args = get_nchw44_channel_wise_args({2, 3, 5}, 1, false, true, true);
    for (auto&& arg : args) {
        checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}});
    }
}
//! int8x8x16 channel-wise NCHW44, stride 2 (multi-thread fixture)
TEST_F(ARM_COMMON_MULTI_THREADS,
       CONV_BIAS_INT8_INT8_INT16_CHANNEL_WISE_DIRECT2_NCHW44) {
    Checker<ConvBias> checker(handle());
    checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>(
            "S8x8x16_CHAN_WISE_STRD1_STRD2_NCHW44"));
    checker.set_dtype(0, dtype::Int8());
    checker.set_dtype(1, dtype::Int8());
    checker.set_dtype(2, dtype::Int16());
    checker.set_dtype(4, dtype::Int16());
    auto args = get_nchw44_channel_wise_args({2, 3, 5}, 2, false, true, true);
    for (auto&& arg : args) {
        checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}});
    }
}
/********************************qint8 direct******************************/
//! quantized s8 direct conv, stride 1
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1) {
    checker_conv_bias_qint8x8x8(
            get_int8_quint8_conv_bias_args({2, 3, 5, 7}, 1, false, false, false),
            handle(), "S8STRD1");
}
//! quantized s8 direct conv, stride 2
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2) {
    checker_conv_bias_qint8x8x8(
            get_int8_quint8_conv_bias_args({2, 3, 5, 7}, 2, false, false, false),
            handle(), "S8STRD2");
}
//! quantized s8 NCHW44 direct conv, stride 1
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_NCHW44) {
    checker_conv_bias_qint8x8x8(
            get_nchw44_conv_bias_args({2, 3, 5, 7}, QUAN_NLMODE, ONLY_BR_BIASMODE, 1),
            handle(), "S8_NCHW44_DIRECT");
}
//! int8x8x16 NCHW44 direct conv, stride 1
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_NCHW44_8816) {
    checker_conv_bias_int8x8x16(
            get_nchw44_conv_bias_args(
                    {2, 3, 5, 7}, ONLY_IDENTITY_NLMODE, ONLY_BR_BIASMODE, 1),
            handle(), "S8x8x16_NCHW44_DIRECT");
}
//! int8x8x16 NCHW44 direct conv, stride 2
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_NCHW44_8816) {
    checker_conv_bias_int8x8x16(
            get_nchw44_conv_bias_args(
                    {2, 3, 5, 7}, ONLY_IDENTITY_NLMODE, ONLY_BR_BIASMODE, 2),
            handle(), "S8x8x16_NCHW44_DIRECT");
}
//! quantized s8 -> s32 NCHW44 direct conv, stride 1
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_NCHW44_8832) {
    checker_conv_bias_qint8x8x32(
            get_nchw44_conv_bias_args(
                    {2, 3, 5, 7}, ONLY_IDENTITY_NLMODE, ONLY_BR_BIASMODE, 1),
            handle(), "S8_NCHW44_DIRECT");
}
//! quantized s8 -> s32 NCHW44 direct conv, stride 2
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_NCHW44_8832) {
    checker_conv_bias_qint8x8x32(
            get_nchw44_conv_bias_args(
                    {2, 3, 5, 7}, ONLY_IDENTITY_NLMODE, ONLY_NO_BIASMODE, 2),
            handle(), "S8_NCHW44_DIRECT");
}
//! quantized s8 NCHW44 direct conv, stride 2
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_NCHW44) {
    checker_conv_bias_qint8x8x8(
            get_nchw44_conv_bias_args({2, 3, 5, 7}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 2),
            handle(), "S8_NCHW44_DIRECT");
}
//! quantized s8 channel-wise NCHW44, stride 1 (dense + channel-wise algos)
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QS8_CHANNEL_WISE_DIRECT1_NCHW44) {
    checker_conv_bias_qint8x8x8(
            get_nchw44_conv_bias_args(
                    {2, 3, 5, 7}, ONLY_IDENTITY_NLMODE, BR_AND_NO_BIASMODE, 1),
            handle(), "S8_NCHW44_DIRECT");
    checker_conv_bias_qint8x8x8(
            get_nchw44_channel_wise_args({2, 3, 5}, 1, false, false, true), handle(),
            "S8_CHAN_WISE_STRD1_NCHW44");
}
//! quantized s8 channel-wise NCHW44, stride 2
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QS8_CHANNEL_WISE_DIRECT2_NCHW44) {
    checker_conv_bias_qint8x8x8(
            get_nchw44_channel_wise_args({2, 3, 5}, 2, false, false, true), handle(),
            "S8_CHAN_WISE_STRD2_NCHW44");
}
//! quantized s8 hybrid NCHW input -> NCHW44 output conv, stride 1
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_NCHW_NCHW44_S1) {
    checker_conv_bias_qint8x8x8(
            get_nchw44_conv_bias_args(
                    {2, 3, 5, 7}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 1, false, true),
            handle(), "S8_CONV_NCHW_NCHW44");
}
//! quantized s8 hybrid NCHW input -> NCHW44 output conv, stride 2
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_NCHW_NCHW44_S2) {
    checker_conv_bias_qint8x8x8(
            get_nchw44_conv_bias_args(
                    {2, 3, 5, 7}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 2, false, true),
            handle(), "S8_CONV_NCHW_NCHW44");
}
//! quantized s8 hybrid conv with 1x1 filter, stride 1
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_NCHW_NCHW44_S1_F1) {
    checker_conv_bias_qint8x8x8(
            get_nchw44_conv_bias_args(
                    {1}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 1, false, true),
            handle(), "S8_CONV_NCHW_NCHW44");
}
/*****************************quint8 direct****************************/
//! asymmetric quint8 direct conv, stride 1
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE1) {
    checker_conv_bias_quint8x8x8(
            get_int8_quint8_conv_bias_args({2, 3, 5, 7}, 1, false, false, false),
            handle(), "QU8STRD1");
}
//! asymmetric quint8 direct conv, stride 2
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE2) {
    checker_conv_bias_quint8x8x8(
            get_int8_quint8_conv_bias_args({2, 3, 5, 7}, 2, false, false, false),
            handle(), "QU8STRD2");
}
/****************************dot qint8 direct*************************/
#if MGB_ENABLE_DOT
//! large-kernel (9x9) channel-wise dot conv, direct algo, stride 1
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_DIRECT_LARGE_S1) {
    checker_conv_bias_qint8x8x8(
            get_channel_wise_args({9}, 1, false, true, true, true), handle(),
            "ARMDOTS8_DIRECT_CHANWISE_LARGE");
}
//! large-kernel (9x9) channel-wise dot conv, direct algo, stride 2
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_DIRECT_LARGE_S2) {
    checker_conv_bias_qint8x8x8(
            get_channel_wise_args({9}, 2, false, true, true, true), handle(),
            "ARMDOTS8_DIRECT_CHANWISE_LARGE");
}
//! large-kernel (9x9) channel-wise dot conv, im2col algo, stride 1
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_IM2COL_LARGE_S1) {
    checker_conv_bias_qint8x8x8(
            get_channel_wise_args({9}, 1, false, true, true, true), handle(),
            "ARMDOTS8_IM2COL_CHANWISE_LARGE");
}
//! large-kernel (9x9) channel-wise dot conv, im2col algo, stride 2
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_IM2COL_LARGE_S2) {
    checker_conv_bias_qint8x8x8(
            get_channel_wise_args({9}, 2, false, true, true, true), handle(),
            "ARMDOTS8_IM2COL_CHANWISE_LARGE");
}
  578. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_NCHW_NCHW44) {
  579. auto args = get_nchw44_conv_bias_args(
  580. {2, 3, 5, 7}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 2, false, true);
  581. for (auto&& arg : args) {
  582. arg.param.format = param::ConvBias::Format::NCHW44_DOT;
  583. }
  584. checker_conv_bias_qint8x8x8(args, handle(), "ARMDOTS8_NCHW_NCHW44");
  585. args = get_nchw44_conv_bias_args(
  586. {2, 3, 5, 7}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 1, false, true);
  587. for (auto&& arg : args) {
  588. arg.param.format = param::ConvBias::Format::NCHW44_DOT;
  589. }
  590. checker_conv_bias_qint8x8x8(args, handle(), "ARMDOTS8_NCHW_NCHW44");
  591. }
  592. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_NCHW_NCHW44_S1_F1) {
  593. auto args = get_nchw44_conv_bias_args(
  594. {1}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 1, false, true);
  595. for (auto&& arg : args) {
  596. arg.param.format = param::ConvBias::Format::NCHW44_DOT;
  597. }
  598. checker_conv_bias_qint8x8x8(args, handle(), "ARMDOTS8_NCHW_NCHW44");
  599. }
//! dot-product quantized s8 direct conv, stride 1
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_WITHDOTPROD) {
    checker_conv_bias_qint8x8x8(
            get_int8_quint8_conv_bias_args({2, 3, 5, 7}, 1, false, false, false),
            handle(), "ARMDOTS8STRD1");
}
//! dot-product quantized s8 direct conv, stride 2
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_WITHDOTPROD) {
    checker_conv_bias_qint8x8x8(
            get_int8_quint8_conv_bias_args({2, 3, 5, 7}, 2, false, false, false),
            handle(), "ARMDOTS8STRD2");
}
/****************************dot 8-8-32 direct*************************/
//! dot-product s8 -> s32 direct conv, stride 1
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_I8832STRD1_WITHDOT) {
    checker_conv_bias_qint8x8x32(
            get_conv_bias_args({2, 3, 5, 7}, 1, false, true, true), handle(),
            "ARMDOTS8STRD1");
}
//! dot-product s8 -> s32 direct conv, stride 2
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_I8832STRD2_WITHDOT) {
    checker_conv_bias_qint8x8x32(
            get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(),
            "ARMDOTS8STRD2");
}
/******************************dot quint8*****************************/
//! dot-product asymmetric quint8 direct conv, stride 1
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE1_WITHDOTPROD) {
    checker_conv_bias_quint8x8x8(
            get_int8_quint8_conv_bias_args({2, 3, 5, 7}, 1, false, false, false),
            handle(), "ARMDOTU8STRD1");
}
//! TODO: this test does not cover kernel size 3; adding it currently causes a
//! bus error on armv7
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE2_WITHDOTPROD) {
    checker_conv_bias_quint8x8x8(
            get_int8_quint8_conv_bias_args({2, 5, 7}, 2, false, false, false), handle(),
            "ARMDOTU8STRD2");
}
/******************************dot quint8x8x32***********************/
//! dot-product quint8 -> s32 direct conv, stride 1
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_QUINT8_DIRECT_STRIDE1) {
    checker_conv_bias_quint8x8x32(
            get_conv_bias_args({2, 3, 5, 7}, 1, false, true, true), handle(),
            "ARMDOTU8STRD1");
}
//! dot-product quint8 -> s32 direct conv, stride 2
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_QUINT8_DIRECT_STRIDE2) {
    checker_conv_bias_quint8x8x32(
            get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(),
            "ARMDOTU8STRD2");
}
/******************************dot int8x8x8 nchw44 ***********************/
//! dot-product NCHW44_DOT direct conv, stride 1, quantized s8 output
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_DIRECT_DOT_NCHW44_S1_Q8x8x8) {
    using namespace conv_bias;
    std::vector<TestArg> args =
            get_nchw44_conv_bias_args({2, 3, 5, 7}, QUAN_NLMODE, ONLY_BR_BIASMODE, 1);
    //! dot kernels require the NCHW44_DOT weight layout
    for (auto&& arg : args)
        arg.param.format = param::ConvBias::Format::NCHW44_DOT;
    checker_conv_bias_qint8x8x8(args, handle(), "ARMDOTS8DIRECT_NCHW44");
}
//! dot-product NCHW44_DOT direct conv, stride 1, quantized s32 output
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_DIRECT_DOT_NCHW44_S1_Q8x8x32) {
    using namespace conv_bias;
    std::vector<TestArg> args = get_nchw44_conv_bias_args(
            {2, 3, 5, 7}, ONLY_IDENTITY_NLMODE, BR_AND_NO_BIASMODE, 1);
    for (auto&& arg : args)
        arg.param.format = param::ConvBias::Format::NCHW44_DOT;
    checker_conv_bias_qint8x8x32(args, handle(), "ARMDOTS8DIRECT_NCHW44");
}
//! dot-product NCHW44_DOT direct conv, stride 1, plain int32 output
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_DIRECT_DOT_NCHW44_S1_8x8x32) {
    using namespace conv_bias;
    std::vector<TestArg> args = get_nchw44_conv_bias_args(
            {2, 3, 5, 7}, ONLY_IDENTITY_NLMODE, BR_AND_NO_BIASMODE, 1);
    for (auto&& arg : args)
        arg.param.format = param::ConvBias::Format::NCHW44_DOT;
    checker_conv_bias_int8x8x32_multi(args, handle(), "ARMDOTS8DIRECT_NCHW44");
}
//! dot-product NCHW44_DOT direct conv, stride 2, quantized s8 output
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_DIRECT_DOT_NCHW44_S2_Q8x8x8) {
    using namespace conv_bias;
    //! test qint8x8x8
    std::vector<TestArg> args =
            get_nchw44_conv_bias_args({2, 3, 5, 7}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 2);
    for (auto&& arg : args)
        arg.param.format = param::ConvBias::Format::NCHW44_DOT;
    checker_conv_bias_qint8x8x8(args, handle(), "ARMDOTS8DIRECT_NCHW44");
}
//! dot-product NCHW44_DOT direct conv, stride 2, quantized s32 output
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_DIRECT_DOT_NCHW44_S2_Q8x8x32) {
    using namespace conv_bias;
    //! test qint8x8x8
    std::vector<TestArg> args = get_nchw44_conv_bias_args(
            {2, 3, 5, 7}, ONLY_IDENTITY_NLMODE, BR_AND_NO_BIASMODE, 2);
    for (auto&& arg : args)
        arg.param.format = param::ConvBias::Format::NCHW44_DOT;
    checker_conv_bias_qint8x8x32(args, handle(), "ARMDOTS8DIRECT_NCHW44");
}
//! dot-product NCHW44_DOT direct conv, stride 2, plain int32 output
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_DIRECT_DOT_NCHW44_S2_8x8x32) {
    using namespace conv_bias;
    //! test qint8x8x8
    std::vector<TestArg> args = get_nchw44_conv_bias_args(
            {2, 3, 5, 7}, ONLY_IDENTITY_NLMODE, BR_AND_NO_BIASMODE, 2);
    for (auto&& arg : args)
        arg.param.format = param::ConvBias::Format::NCHW44_DOT;
    checker_conv_bias_int8x8x32_multi(args, handle(), "ARMDOTS8DIRECT_NCHW44");
}
#endif
  698. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD) {
  699. using namespace conv_bias;
  700. std::vector<TestArg> args = get_winograd_args(3);
  701. Checker<ConvBiasForward> checker(handle());
  702. auto run = [&checker](
  703. const std::vector<TestArg>& args, DType A_dtype, DType B_dtype,
  704. DType C_dtype, DType D_dtype, const float eps) {
  705. for (auto&& arg : args) {
  706. checker.set_dtype(0, A_dtype)
  707. .set_dtype(1, B_dtype)
  708. .set_dtype(2, C_dtype)
  709. .set_dtype(4, D_dtype)
  710. .set_epsilon(eps)
  711. .set_param(arg.param)
  712. .execs({arg.src, arg.filter, arg.bias, {}, {}});
  713. }
  714. };
  715. run(args, dtype::Float32(), dtype::Float32(), dtype::Float32(), dtype::Float32(),
  716. 1e-3f);
  717. #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  718. Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00);
  719. checker.set_rng(0, rng).set_rng(1, rng).set_rng(2, rng);
  720. run(args, dtype::Float16(), dtype::Float16(), dtype::Float16(), dtype::Float16(),
  721. 0.35f);
  722. #endif
  723. }
  724. //! uncomment it when low precision mode is ok
  725. #if 0
  726. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F73_4_NCHW44) {
  727. using namespace conv_bias;
  728. std::vector<TestArg> args =
  729. get_nchw44_conv_bias_args({3},QUAN_NLMODE,BR_AND_NO_BIASMODE,1);
  730. Checker<ConvBiasForward> checker(handle());
  731. check_winograd("4:7:16", checker, args, param::MatrixMul::Format::MK4,
  732. param::ConvBias::Format::NCHW44);
  733. }
  734. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F73_4_NCHW44_WEIGHT_PREPROCESS) {
  735. using namespace conv_bias;
  736. std::vector<TestArg> args =
  737. get_nchw44_conv_bias_args({3},QUAN_NLMODE,BR_AND_NO_BIASMODE,1);
  738. Checker<ConvBiasForward, OprWeightPreprocessProxy<ConvBiasForward>> checker(
  739. handle());
  740. check_winograd("4:7:16", checker, args, param::MatrixMul::Format::MK4,
  741. param::ConvBias::Format::NCHW44);
  742. }
  743. #endif
  744. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_MK_PACKED_F32_1) {
  745. using namespace conv_bias;
  746. Checker<ConvBiasForward> checker(handle());
  747. auto run = [&checker](
  748. const std::vector<TestArg>& args, DType A_dtype, DType B_dtype,
  749. DType C_dtype, DType D_dtype, float eps) {
  750. for (auto&& arg : args) {
  751. checker.set_dtype(0, A_dtype)
  752. .set_dtype(1, B_dtype)
  753. .set_dtype(2, C_dtype)
  754. .set_dtype(4, D_dtype)
  755. .set_epsilon(eps)
  756. .set_param(arg.param)
  757. .execs({arg.src, arg.filter, arg.bias, {}, {}});
  758. }
  759. };
  760. std::vector<TestArg> args = get_winograd_mk_packed_args(8);
  761. std::vector<TestArg> args_first_half(args.begin(), args.begin() + args.size() / 2);
  762. run(args_first_half, dtype::Float32{}, dtype::Float32{}, dtype::Float32{},
  763. dtype::Float32{}, 1e-3f);
  764. }
  765. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_MK_PACKED_F32_2) {
  766. using namespace conv_bias;
  767. Checker<ConvBiasForward> checker(handle());
  768. auto run = [&checker](
  769. const std::vector<TestArg>& args, DType A_dtype, DType B_dtype,
  770. DType C_dtype, DType D_dtype, float eps) {
  771. for (auto&& arg : args) {
  772. checker.set_dtype(0, A_dtype)
  773. .set_dtype(1, B_dtype)
  774. .set_dtype(2, C_dtype)
  775. .set_dtype(4, D_dtype)
  776. .set_epsilon(eps)
  777. .set_param(arg.param)
  778. .execs({arg.src, arg.filter, arg.bias, {}, {}});
  779. }
  780. };
  781. std::vector<TestArg> args = get_winograd_mk_packed_args(8);
  782. std::vector<TestArg> args_second_half(args.begin() + args.size() / 2, args.end());
  783. run(args_second_half, dtype::Float32{}, dtype::Float32{}, dtype::Float32{},
  784. dtype::Float32{}, 1e-3f);
  785. }
  786. #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  787. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_MK_PACKED_F16) {
  788. using namespace conv_bias;
  789. Checker<ConvBiasForward> checker(handle());
  790. auto run = [&checker](
  791. const std::vector<TestArg>& args, DType A_dtype, DType B_dtype,
  792. DType C_dtype, DType D_dtype, float eps) {
  793. for (auto&& arg : args) {
  794. checker.set_dtype(0, A_dtype)
  795. .set_dtype(1, B_dtype)
  796. .set_dtype(2, C_dtype)
  797. .set_dtype(4, D_dtype)
  798. .set_epsilon(eps)
  799. .set_param(arg.param)
  800. .execs({arg.src, arg.filter, arg.bias, {}, {}});
  801. }
  802. };
  803. std::vector<TestArg> args = get_winograd_mk_packed_args(8);
  804. Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00);
  805. checker.set_rng(0, rng).set_rng(1, rng).set_rng(2, rng);
  806. run(args, dtype::Float16{}, dtype::Float16{}, dtype::Float16{}, dtype::Float16{},
  807. 0.25);
  808. }
  809. #endif
  810. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_MK_PACKED_INT8) {
  811. using namespace conv_bias;
  812. Checker<ConvBiasForward> checker(handle());
  813. auto run = [&checker](
  814. const std::vector<TestArg>& args, DType A_dtype, DType B_dtype,
  815. DType C_dtype, DType D_dtype, float eps) {
  816. for (auto&& arg : args) {
  817. checker.set_dtype(0, A_dtype)
  818. .set_dtype(1, B_dtype)
  819. .set_dtype(2, C_dtype)
  820. .set_dtype(4, D_dtype)
  821. .set_epsilon(eps)
  822. .set_param(arg.param)
  823. .execs({arg.src, arg.filter, arg.bias, {}, {}});
  824. }
  825. };
  826. #if MEGDNN_AARCH64
  827. const char* matmul_name = "AARCH64_INT16X16X32_MK8_8X8";
  828. #else
  829. const char* matmul_name = "ARMV7_INT16X16X32_MK8_4X8";
  830. #endif
  831. checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>(
  832. ssprintf("WINOGRAD:%s:8:2:32", matmul_name).c_str()));
  833. std::vector<TestArg> quantized_args = get_quantized_winograd_mk_packed_args(8);
  834. UniformIntRNG int_rng{-50, 50};
  835. checker.set_rng(0, &int_rng).set_rng(1, &int_rng).set_rng(2, &int_rng);
  836. run(quantized_args, dtype::QuantizedS8(2.5f), dtype::QuantizedS8(2.5f),
  837. dtype::QuantizedS32(6.25f), dtype::QuantizedS8(60.25f), 1e-3);
  838. }
  839. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8) {
  840. using namespace conv_bias;
  841. Checker<ConvBiasForward> checker(handle());
  842. auto run = [&checker](
  843. const std::vector<TestArg>& args, DType A_dtype, DType B_dtype,
  844. DType C_dtype, DType D_dtype, float eps) {
  845. for (auto&& arg : args) {
  846. checker.set_dtype(0, A_dtype)
  847. .set_dtype(1, B_dtype)
  848. .set_dtype(2, C_dtype)
  849. .set_dtype(4, D_dtype)
  850. .set_epsilon(eps)
  851. .set_param(arg.param)
  852. .execs({arg.src, arg.filter, arg.bias, {}, {}});
  853. }
  854. };
  855. #if MEGDNN_AARCH64
  856. const char* matmul_name = "AARCH64_INT16X16X32_MK8_8X8";
  857. #else
  858. const char* matmul_name = "ARMV7_INT16X16X32_MK8_4X8";
  859. #endif
  860. checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>(
  861. ssprintf("WINOGRAD_NCHW44:%s:8:2:32", matmul_name).c_str()));
  862. std::vector<TestArg> quantized_args = get_int8_nchw44_args(3, 4);
  863. UniformIntRNG int_rng{-50, 50};
  864. checker.set_rng(0, &int_rng).set_rng(1, &int_rng).set_rng(2, &int_rng);
  865. run(quantized_args, dtype::QuantizedS8(2.5f), dtype::QuantizedS8(2.5f),
  866. dtype::QuantizedS32(6.25f), dtype::QuantizedS8(60.25f), 1e-3);
  867. }
  868. TEST_F(ARM_COMMON_MULTI_THREADS,
  869. CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8_COMP_F32_GROUPMODE) {
  870. using namespace conv_bias;
  871. Checker<ConvBiasForward> checker(handle());
  872. auto run = [&checker](
  873. const std::vector<TestArg>& args, DType A_dtype, DType B_dtype,
  874. DType C_dtype, DType D_dtype, float eps) {
  875. for (auto&& arg : args) {
  876. checker.set_dtype(0, A_dtype)
  877. .set_dtype(1, B_dtype)
  878. .set_dtype(2, C_dtype)
  879. .set_dtype(4, D_dtype)
  880. .set_epsilon(eps)
  881. .set_param(arg.param)
  882. .execs({arg.src, arg.filter, arg.bias, {}, {}});
  883. }
  884. };
  885. float epsilon = 0.001;
  886. #if MEGDNN_AARCH64
  887. const char* matmul_name = "AARCH64_F32_MK4_4x16";
  888. #else
  889. const char* matmul_name = "ARMV7_F32_MK4_4x8";
  890. #endif
  891. checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>(
  892. ssprintf("WINOGRAD_NCHW44:%s:4:2:32", matmul_name).c_str()));
  893. std::vector<TestArg> quantized_args = get_int8_nchw44_args(3, 4, true, true);
  894. UniformIntRNG int_rng{-50, 50};
  895. checker.set_rng(0, &int_rng).set_rng(1, &int_rng).set_rng(2, &int_rng);
  896. run(quantized_args, dtype::QuantizedS8(0.41113496f),
  897. dtype::QuantizedS8(0.01887994f), dtype::QuantizedS32(0.41113496f * 0.01887994f),
  898. dtype::QuantizedS8(0.49550694f), epsilon);
  899. }
  900. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8_GROUPMODE) {
  901. using namespace conv_bias;
  902. Checker<ConvBiasForward> checker(handle());
  903. auto run = [&checker](
  904. const std::vector<TestArg>& args, DType A_dtype, DType B_dtype,
  905. DType C_dtype, DType D_dtype, float eps) {
  906. for (auto&& arg : args) {
  907. checker.set_dtype(0, A_dtype)
  908. .set_dtype(1, B_dtype)
  909. .set_dtype(2, C_dtype)
  910. .set_dtype(4, D_dtype)
  911. .set_epsilon(eps)
  912. .set_param(arg.param)
  913. .execs({arg.src, arg.filter, arg.bias, {}, {}});
  914. }
  915. };
  916. #if MEGDNN_AARCH64
  917. const char* matmul_name = "AARCH64_INT16X16X32_MK8_8X8";
  918. #else
  919. const char* matmul_name = "ARMV7_INT16X16X32_MK8_4X8";
  920. #endif
  921. checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>(
  922. ssprintf("WINOGRAD_NCHW44:%s:8:2:32", matmul_name).c_str()));
  923. std::vector<TestArg> quantized_args = get_int8_nchw44_args(3, 4, false, true);
  924. UniformIntRNG int_rng{-50, 50};
  925. checker.set_rng(0, &int_rng).set_rng(1, &int_rng).set_rng(2, &int_rng);
  926. run(quantized_args, dtype::QuantizedS8(2.5f), dtype::QuantizedS8(2.5f),
  927. dtype::QuantizedS32(6.25f), dtype::QuantizedS8(60.25f), 1e-3);
  928. }
  929. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8_COMP_F32) {
  930. using namespace conv_bias;
  931. Checker<ConvBiasForward> checker(handle());
  932. auto run = [&checker](
  933. const std::vector<TestArg>& args, DType A_dtype, DType B_dtype,
  934. DType C_dtype, DType D_dtype, float eps) {
  935. for (auto&& arg : args) {
  936. checker.set_dtype(0, A_dtype)
  937. .set_dtype(1, B_dtype)
  938. .set_dtype(2, C_dtype)
  939. .set_dtype(4, D_dtype)
  940. .set_epsilon(eps)
  941. .set_param(arg.param)
  942. .execs({arg.src, arg.filter, arg.bias, {}, {}});
  943. }
  944. };
  945. float epsilon = 0.001;
  946. #if MEGDNN_AARCH64
  947. const char* matmul_name = "AARCH64_F32_MK4_4x16";
  948. #else
  949. const char* matmul_name = "ARMV7_F32_MK4_4x8";
  950. #endif
  951. checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>(
  952. ssprintf("WINOGRAD_NCHW44:%s:4:2:32", matmul_name).c_str()));
  953. std::vector<TestArg> quantized_args = get_int8_nchw44_args(3, 4, true);
  954. UniformIntRNG int_rng{-50, 50};
  955. checker.set_rng(0, &int_rng).set_rng(1, &int_rng).set_rng(2, &int_rng);
  956. run(quantized_args, dtype::QuantizedS8(0.41113496f),
  957. dtype::QuantizedS8(0.01887994f), dtype::QuantizedS32(0.41113496f * 0.01887994f),
  958. dtype::QuantizedS8(0.49550694f), epsilon);
  959. }
  960. #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  961. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_F23) {
  962. using namespace conv_bias;
  963. std::vector<TestArg> args = get_winograd_mk_packed_args();
  964. Checker<ConvBiasForward> checker(handle());
  965. check_winograd_fp16("1:2:32", checker, args, NULL, 0.08);
  966. }
  967. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_F45_1) {
  968. using namespace conv_bias;
  969. std::vector<TestArg> args = get_winograd_args(5);
  970. std::vector<TestArg> args_head_half(args.begin(), args.begin() + args.size() / 2);
  971. Checker<ConvBiasForward> checker(handle());
  972. //! fp16 range -1.0 ~ 1.0
  973. Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00);
  974. check_winograd_fp16("1:4:32", checker, args_head_half, rng, 0.25);
  975. }
  976. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_F45_2) {
  977. using namespace conv_bias;
  978. std::vector<TestArg> args = get_winograd_args(5);
  979. std::vector<TestArg> args_back_half(args.begin() + args.size() / 2, args.end());
  980. Checker<ConvBiasForward> checker(handle());
  981. //! fp16 range -1.0 ~ 1.0
  982. Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00);
  983. check_winograd_fp16("1:4:32", checker, args_back_half, rng, 0.25);
  984. }
//! FIXME: this test may fail when run together with `ARM_COMMON.CONV_BIAS_WINOGRAD*`,
//! but it passes when run as a single testcase
  987. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_F63) {
  988. using namespace conv_bias;
  989. std::vector<TestArg> args = get_winograd_args(3);
  990. Checker<ConvBiasForward> checker(handle());
  991. //! fp16 range -1.0 ~ 1.0
  992. Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00);
  993. check_winograd_fp16("1:6:32", checker, args, rng, 0.3);
  994. }
  995. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_8x8_1) {
  996. using namespace conv_bias;
  997. std::vector<TestArg> args = get_winograd_mk_packed_args(8);
  998. std::vector<TestArg> args_head_half(args.begin(), args.begin() + args.size() / 2);
  999. Checker<ConvBiasForward> checker(handle());
  1000. Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00);
  1001. check_winograd_fp16(
  1002. "8:2:32", checker, args_head_half, rng, 0.25,
  1003. param::MatrixMul::Format::MK8);
  1004. }
  1005. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_8x8_2) {
  1006. using namespace conv_bias;
  1007. std::vector<TestArg> args = get_winograd_mk_packed_args(8);
  1008. std::vector<TestArg> args_back_half(args.begin() + args.size() / 2, args.end());
  1009. Checker<ConvBiasForward> checker(handle());
  1010. Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00);
  1011. check_winograd_fp16(
  1012. "8:2:32", checker, args_back_half, rng, 0.25,
  1013. param::MatrixMul::Format::MK8);
  1014. }
  1015. #endif
  1016. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_INT8_8X8) {
  1017. using namespace conv_bias;
  1018. std::vector<TestArg> args = get_quantized_winograd_mk_packed_args(8);
  1019. Checker<ConvBiasForward> checker(handle());
  1020. UniformIntRNG rng{-50, 50};
  1021. checker.set_dtype(0, dtype::QuantizedS8(2.5f))
  1022. .set_dtype(1, dtype::QuantizedS8(2.5f))
  1023. .set_dtype(2, dtype::QuantizedS32(6.25f))
  1024. .set_dtype(4, dtype::QuantizedS8(60.25f))
  1025. .set_rng(0, &rng)
  1026. .set_rng(1, &rng)
  1027. .set_rng(2, &rng);
  1028. check_winograd("8:2:32", checker, args, param::MatrixMul::Format::MK8);
  1029. }
  1030. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_INT8_8X8_WEIGHT_PREPROCESS) {
  1031. using namespace conv_bias;
  1032. std::vector<TestArg> args = get_quantized_winograd_mk_packed_args(8);
  1033. Checker<ConvBiasForward, OprWeightPreprocessProxy<ConvBiasForward>> checker(
  1034. handle());
  1035. UniformIntRNG rng{-50, 50};
  1036. checker.set_dtype(0, dtype::QuantizedS8(2.5f))
  1037. .set_dtype(1, dtype::QuantizedS8(2.5f))
  1038. .set_dtype(2, dtype::QuantizedS32(6.25f))
  1039. .set_dtype(4, dtype::QuantizedS8(60.25f))
  1040. .set_rng(0, &rng)
  1041. .set_rng(1, &rng)
  1042. .set_rng(2, &rng);
  1043. check_winograd("8:2:32", checker, args, param::MatrixMul::Format::MK8);
  1044. }
  1045. // clang-format on
  1046. // vim: syntax=cpp.doxygen