You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

conv_bias_multi_thread.cpp 91 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141
  1. /**
  2. * \file dnn/test/arm_common/conv_bias_multi_thread.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #include "test/arm_common/fixture.h"
  13. #include "test/common/benchmarker.h"
  14. #include "test/common/conv_bias.h"
  15. using namespace megdnn;
  16. using namespace test;
  17. using namespace conv_bias;
  18. std::vector<conv_bias::TestArg> get_int8_quint8_conv_bias_args(
  19. std::vector<size_t> kernel, size_t stride, bool no_pad, bool no_bias,
  20. bool no_nonlinemode) {
  21. using namespace conv_bias;
  22. using Param = param::ConvBias;
  23. using NLMode = param::ConvBias::NonlineMode;
  24. std::vector<TestArg> args;
  25. auto pack = [&](size_t n, size_t oc, size_t ic, size_t w, size_t h,
  26. size_t kernel, size_t stride, NLMode nlmode) {
  27. Param param;
  28. param.stride_h = stride;
  29. param.stride_w = stride;
  30. if (!no_pad) {
  31. param.pad_h = kernel / 2;
  32. param.pad_w = kernel / 2;
  33. } else {
  34. param.pad_h = 0;
  35. param.pad_w = 0;
  36. }
  37. param.nonlineMode = nlmode;
  38. args.emplace_back(param, TensorShape{n, ic, h, w},
  39. TensorShape{oc, ic, kernel, kernel}, TensorShape{});
  40. if (!no_bias) {
  41. args.emplace_back(param, TensorShape{n, ic, h, w},
  42. TensorShape{oc, ic, kernel, kernel},
  43. TensorShape{1, oc, 1, 1});
  44. }
  45. };
  46. std::vector<NLMode> nonlinemode = {NLMode::IDENTITY};
  47. if (!no_nonlinemode) {
  48. nonlinemode.emplace_back(NLMode::RELU);
  49. nonlinemode.emplace_back(NLMode::H_SWISH);
  50. }
  51. for (size_t n : {1, 2}) {
  52. for (auto nlmode : nonlinemode) {
  53. for (size_t ic : {1, 3, 7}) {
  54. for (size_t oc : {1, 3, 7}) {
  55. for (size_t size : {4, 6, 8, 14, 16, 18}) {
  56. for (size_t kern : kernel) {
  57. pack(n, oc, ic, size, size, kern, stride, nlmode);
  58. }
  59. }
  60. }
  61. }
  62. }
  63. }
  64. return args;
  65. }
  66. std::vector<conv_bias::TestArg> get_nchw44_conv_bias_args(
  67. std::vector<size_t> kernel_vec, size_t stride, bool no_pad = false,
  68. bool no_bias = false, bool no_nonlinemode = false,
  69. bool is_input_nchw = false, bool is_nchw44_dot = false,
  70. bool support_full_bias = false, bool support_sigmoid = false,
  71. bool only_no_bias = false) {
  72. using namespace conv_bias;
  73. using NLMode = param::ConvBias::NonlineMode;
  74. std::vector<TestArg> args;
  75. auto pack = [&](size_t n, size_t oc, size_t ic, size_t h, size_t w,
  76. size_t kernel, size_t stride, size_t group, NLMode nlmode,
  77. megdnn::BiasMode bias_mode, int any_pad = -1) {
  78. constexpr int pack_c = 4;
  79. const size_t pad = any_pad >= 0 ? any_pad : kernel / 2;
  80. auto oc_per_group = oc / group;
  81. auto ic_per_group = ic / group;
  82. bool ok_group = (oc % group == 0 && ic % group == 0) &&
  83. oc_per_group % pack_c == 0 && oc_per_group > 0 &&
  84. ic_per_group > 0;
  85. bool nchw_disable = group > 1 || ic_per_group >= 4;
  86. bool nchw44_disable = ic_per_group % pack_c != 0;
  87. bool invalid_pad = (w + 2 * pad < kernel) || (h + 2 * pad < kernel);
  88. if (!(ok_group) || invalid_pad) {
  89. return;
  90. }
  91. if ((is_input_nchw && nchw_disable) ||
  92. (!is_input_nchw && nchw44_disable)) {
  93. return;
  94. }
  95. size_t kernel_h = kernel;
  96. size_t kernel_w = kernel;
  97. param::ConvBias param;
  98. if (!is_nchw44_dot) {
  99. param.format = param::ConvBias::Format::NCHW44;
  100. } else {
  101. param.format = param::ConvBias::Format::NCHW44_DOT;
  102. }
  103. param.stride_h = stride;
  104. param.stride_w = stride;
  105. param.pad_h = pad;
  106. param.pad_w = pad;
  107. param.nonlineMode = nlmode;
  108. auto src_tensor_shape = TensorShape{n, ic / pack_c, h, w, pack_c};
  109. auto weight_tensor_shape = TensorShape{
  110. oc / pack_c, ic / pack_c, kernel_h, kernel_w, pack_c, pack_c};
  111. auto bias_tensor_shape = TensorShape{};
  112. if (bias_mode == megdnn::BiasMode::BROADCAST_CHANNEL_BIAS) {
  113. bias_tensor_shape = {1, oc / pack_c, 1, 1, pack_c};
  114. } else if (bias_mode == megdnn::BiasMode::BIAS) {
  115. bias_tensor_shape = {n, oc / pack_c,
  116. (h + 2 * pad - kernel) / stride + 1,
  117. (w + 2 * pad - kernel) / stride + 1, pack_c};
  118. }
  119. if (group == 1) {
  120. param.sparse = param::ConvBias::Sparse::DENSE;
  121. } else if (group > 1 && ic / group == 1 && oc / group == 1) {
  122. megdnn_assert(0, "not support channel wise");
  123. param.sparse = param::ConvBias::Sparse::GROUP;
  124. weight_tensor_shape = TensorShape{group / pack_c, 1, 1,
  125. kernel_h, kernel_w, pack_c};
  126. } else if (group > 1 && oc_per_group % pack_c == 0 && oc / group > 0 &&
  127. ic_per_group % pack_c == 0 && ic / group > 0) {
  128. param.sparse = param::ConvBias::Sparse::GROUP;
  129. weight_tensor_shape = TensorShape{group,
  130. oc_per_group / pack_c,
  131. ic_per_group / pack_c,
  132. kernel_h,
  133. kernel_w,
  134. pack_c,
  135. pack_c};
  136. }
  137. if (is_input_nchw) {
  138. src_tensor_shape = TensorShape{n, ic, h, w};
  139. weight_tensor_shape =
  140. TensorShape{oc / pack_c, kernel_h, kernel_w, ic, pack_c};
  141. }
  142. args.emplace_back(param, src_tensor_shape, weight_tensor_shape,
  143. bias_tensor_shape);
  144. };
  145. std::vector<NLMode> nonlinemode = {NLMode::IDENTITY};
  146. if (!no_nonlinemode) {
  147. nonlinemode.emplace_back(NLMode::RELU);
  148. nonlinemode.emplace_back(NLMode::H_SWISH);
  149. }
  150. if (support_sigmoid) {
  151. nonlinemode.emplace_back(NLMode::SIGMOID);
  152. }
  153. std::vector<megdnn::BiasMode> bias_mode;
  154. if (!only_no_bias) {
  155. bias_mode.emplace_back(megdnn::BiasMode::BROADCAST_CHANNEL_BIAS);
  156. if (no_bias) {
  157. bias_mode.emplace_back(megdnn::BiasMode::NO_BIAS);
  158. }
  159. } else {
  160. bias_mode.emplace_back(megdnn::BiasMode::NO_BIAS);
  161. }
  162. if (support_full_bias) {
  163. bias_mode.emplace_back(megdnn::BiasMode::BIAS);
  164. }
  165. for (auto bias : bias_mode)
  166. for (auto nlmode : nonlinemode)
  167. for (size_t n : {1, 2})
  168. for (size_t kernel : kernel_vec)
  169. for (size_t oc : {4, 12})
  170. for (size_t ic : {1, 3, 4, 12})
  171. for (size_t h : {3, 5, 12})
  172. for (size_t w : {7, 16, 23}) {
  173. for (size_t group = 1;
  174. group <=
  175. std::min(std::min(oc, ic), 4_z);
  176. ++group) {
  177. pack(n, oc, ic, h, w, kernel, stride,
  178. group, nlmode, bias);
  179. }
  180. }
  181. return args;
  182. }
  183. std::vector<conv_bias::TestArg> get_nchw44_channel_wise_args(
  184. std::vector<size_t> kernel, size_t stride, bool no_bias,
  185. bool no_nonlinemode, bool no_full_bias) {
  186. using namespace conv_bias;
  187. using Param = param::ConvBias;
  188. using NLMode = param::ConvBias::NonlineMode;
  189. std::vector<TestArg> args;
  190. auto pack = [&](size_t n, size_t group, size_t w, size_t h, size_t kernel,
  191. size_t stride, NLMode nlmode, bool pad) {
  192. Param param;
  193. param.stride_h = stride;
  194. param.stride_w = stride;
  195. if (pad) {
  196. param.pad_h = kernel / 2;
  197. param.pad_w = kernel / 2;
  198. } else {
  199. param.pad_h = 0;
  200. param.pad_w = 0;
  201. }
  202. param.nonlineMode = nlmode;
  203. param.format = param::ConvBias::Format::NCHW44;
  204. param.sparse = param::ConvBias::Sparse::GROUP;
  205. args.emplace_back(param, TensorShape{n, group, h, w, 4},
  206. TensorShape{group, 1, 1, kernel, kernel, 4},
  207. TensorShape{});
  208. if (!no_bias) {
  209. args.emplace_back(param, TensorShape{n, group, h, w, 4},
  210. TensorShape{group, 1, 1, kernel, kernel, 4},
  211. TensorShape{1, group, 1, 1, 4});
  212. }
  213. if (!no_full_bias) {
  214. args.emplace_back(
  215. param, TensorShape{n, group, h, w, 4},
  216. TensorShape{group, 1, 1, kernel, kernel, 4},
  217. TensorShape{n, group,
  218. (h + 2 * param.pad_w - kernel) / stride + 1,
  219. (w + 2 * param.pad_w - kernel) / stride + 1,
  220. 4});
  221. }
  222. };
  223. std::vector<NLMode> nonlinemode = {NLMode::IDENTITY};
  224. if (!no_nonlinemode) {
  225. nonlinemode.emplace_back(NLMode::RELU);
  226. nonlinemode.emplace_back(NLMode::H_SWISH);
  227. }
  228. for (size_t n : {1, 2}) {
  229. for (auto nlmode : nonlinemode) {
  230. for (bool pad : {true}) {
  231. for (size_t group : {1, 2, 4, 7, 128}) {
  232. for (size_t size : {4, 6, 7, 9, 15, 40}) {
  233. for (size_t kern : kernel) {
  234. pack(n, group, size, size, kern, stride, nlmode,
  235. pad);
  236. }
  237. }
  238. }
  239. }
  240. for (bool pad : {false}) {
  241. for (size_t group : {1, 2, 7, 128}) {
  242. for (size_t size : {7, 9, 15, 40}) {
  243. for (size_t kern : kernel) {
  244. pack(n, group, size, size, kern, stride, nlmode,
  245. pad);
  246. }
  247. }
  248. }
  249. }
  250. }
  251. }
  252. return args;
  253. }
  254. void checker_conv_bias_qint8x8x8(std::vector<conv_bias::TestArg> args,
  255. Handle* handle, const char* algo_name) {
  256. Checker<ConvBias> checker(handle);
  257. checker.set_before_exec_callback(
  258. conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name));
  259. #if MEGDNN_ARMV7
  260. checker.set_epsilon(1);
  261. #endif
  262. UniformIntRNG rng{-50, 50};
  263. checker.set_dtype(0, dtype::QuantizedS8(0.41113496f))
  264. .set_dtype(1, dtype::QuantizedS8(0.01887994f))
  265. .set_dtype(2, dtype::QuantizedS32(0.41113496f * 0.01887994f))
  266. .set_dtype(4, dtype::QuantizedS8(0.49550694f))
  267. .set_rng(0, &rng)
  268. .set_rng(1, &rng)
  269. .set_rng(2, &rng);
  270. for (auto&& arg : args) {
  271. checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}});
  272. }
  273. }
  274. void checker_conv_bias_qint8x8x32(std::vector<conv_bias::TestArg> args,
  275. Handle* handle, const char* algo_name) {
  276. Checker<ConvBias> checker(handle);
  277. UniformIntRNG rng{-50, 50};
  278. checker.set_dtype(0, dtype::QuantizedS8(2.5f))
  279. .set_dtype(1, dtype::QuantizedS8(2.5f))
  280. .set_dtype(2, dtype::QuantizedS32(6.25f))
  281. .set_dtype(4, {});
  282. checker.set_before_exec_callback(
  283. conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name));
  284. for (auto&& arg : args) {
  285. checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}});
  286. }
  287. }
  288. void checker_conv_bias_quint8x8x8(std::vector<conv_bias::TestArg> args,
  289. Handle* handle, const char* algo_name) {
  290. Checker<ConvBias> checker(handle);
  291. checker.set_before_exec_callback(
  292. conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name));
  293. UniformIntRNG rng(0, 255);
  294. checker.set_dtype(0, dtype::Quantized8Asymm(0.2f, 100))
  295. .set_dtype(1, dtype::Quantized8Asymm(0.2f, 120))
  296. .set_dtype(2, dtype::QuantizedS32(0.04f))
  297. .set_dtype(4, dtype::Quantized8Asymm(1.4f, 110))
  298. .set_rng(0, &rng)
  299. .set_rng(1, &rng)
  300. .set_rng(2, &rng);
  301. for (auto&& arg : args) {
  302. checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}});
  303. }
  304. }
  305. void checker_conv_bias_quint8x8x32(std::vector<conv_bias::TestArg> args,
  306. Handle* handle, const char* algo_name) {
  307. Checker<ConvBias> checker(handle);
  308. checker.set_before_exec_callback(
  309. conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name));
  310. NormalRNG rng(128.f);
  311. checker.set_rng(0, &rng).set_rng(1, &rng);
  312. checker.set_dtype(0, dtype::Quantized8Asymm(1.2f, (uint8_t)127))
  313. .set_dtype(1, dtype::Quantized8Asymm(1.3f, (uint8_t)129))
  314. .set_dtype(2, dtype::QuantizedS32(1.2 * 1.3))
  315. .set_dtype(4, {});
  316. for (auto&& arg : args) {
  317. checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}});
  318. }
  319. }
  320. void checker_conv_bias_int8x8x32_multi(std::vector<conv_bias::TestArg> args,
  321. Handle* handle, const char* algo_name) {
  322. Checker<ConvBias> checker(handle);
  323. checker.set_before_exec_callback(
  324. conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name));
  325. checker.set_dtype(0, dtype::Int8());
  326. checker.set_dtype(1, dtype::Int8());
  327. checker.set_dtype(2, dtype::Int32());
  328. checker.set_dtype(4, dtype::Int32());
  329. for (auto&& arg : args) {
  330. checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}});
  331. }
  332. }
  333. /**********************************F32 direct************************/
  334. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_LARGE_GROUP) {
  335. check_conv_bias(
  336. get_conv_bias_args({1, 2, 3, 4, 5, 6, 7}, 1, false, false, false),
  337. handle(), "F32DIRECT_LARGE_GROUP");
  338. }
  339. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_SMALL_GROUP) {
  340. check_conv_bias(
  341. get_conv_bias_args({1, 2, 3, 4, 5, 6, 7}, 1, false, false, false),
  342. handle(), "F32DIRECT_SMALL_GROUP");
  343. }
  344. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S1_K7) {
  345. check_conv_bias(get_nchw44_conv_bias_args({7}, 1, false, true, true, false,
  346. false, false),
  347. handle(), "F32_CONV_NCHW44_DIRECT");
  348. }
  349. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S1_K2K3) {
  350. check_conv_bias(get_nchw44_conv_bias_args({2, 3}, 1, false, false, false,
  351. false, false, true, true),
  352. handle(), "F32_CONV_NCHW44_DIRECT");
  353. }
  354. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S1_K5) {
  355. check_conv_bias(get_nchw44_conv_bias_args({5}, 1, false, false, false,
  356. false, false, true, true),
  357. handle(), "F32_CONV_NCHW44_DIRECT");
  358. }
  359. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S2) {
  360. check_conv_bias(get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, false,
  361. false, false, false, true, true),
  362. handle(), "F32_CONV_NCHW44_DIRECT");
  363. }
  364. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR1_LARGE_GROUP) {
  365. check_conv_bias(get_conv_bias_args({2, 3, 5, 7}, 1, false, false, false),
  366. handle(), "F32STRD1_LARGE_GROUP");
  367. }
  368. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR1_SMALL_GROUP) {
  369. check_conv_bias(get_conv_bias_args({2, 3, 5, 7}, 1, false, false, false),
  370. handle(), "F32STRD1_SMALL_GROUP");
  371. }
  372. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR2_LARGE_GROUP) {
  373. check_conv_bias(get_conv_bias_args({2, 3, 5, 7}, 2, false, false, false),
  374. handle(), "F32STRD2_LARGE_GROUP");
  375. }
  376. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR2_SMALL_GROUP) {
  377. check_conv_bias(get_conv_bias_args({2, 3, 5, 7}, 2, false, false, false),
  378. handle(), "F32STRD2_SMALL_GROUP");
  379. }
  380. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_NCHW_NCHW44_F32_S2) {
  381. check_conv_bias(get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, false,
  382. false, true),
  383. handle(), "F32_CONV_NCHW_NCHW44");
  384. }
  385. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_NCHW_NCHW44_F32_S1) {
  386. check_conv_bias(get_nchw44_conv_bias_args({2, 3, 5, 7}, 1, false, false,
  387. false, true),
  388. handle(), "F32_CONV_NCHW_NCHW44");
  389. }
  390. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CHANNEL_WISE_STRIDE1_FP32_NCHW44_1) {
  391. check_conv_bias(
  392. get_nchw44_channel_wise_args({2, 3}, 1, false, false, false),
  393. handle(), "F32_CHANNEL_WISE_NCHW44");
  394. }
  395. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CHANNEL_WISE_STRIDE1_FP32_NCHW44_2) {
  396. check_conv_bias(get_nchw44_channel_wise_args({5}, 1, false, false, false),
  397. handle(), "F32_CHANNEL_WISE_NCHW44");
  398. }
  399. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CHANNEL_WISE_STRIDE2_FP32_NCHW44) {
  400. check_conv_bias(
  401. get_nchw44_channel_wise_args({2, 3, 5}, 2, false, false, false),
  402. handle(), "F32_CHANNEL_WISE_NCHW44");
  403. }
  404. /**********************************F16 direct************************/
  405. #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  406. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16_LARGE_GROUP) {
  407. NormalRNG rng(1);
  408. checker_conv_bias_f16(
  409. get_conv_bias_args({1, 2, 3, 4, 5, 6, 7}, 1, false, false, false),
  410. handle(), rng, "F16DIRECT_LARGE_GROUP", 0.03);
  411. }
  412. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16_SMALL_GROUP) {
  413. NormalRNG rng(1);
  414. checker_conv_bias_f16(
  415. get_conv_bias_args({1, 2, 3, 4, 5, 6, 7}, 1, false, false, false),
  416. handle(), rng, "F16DIRECT_SMALL_GROUP", 0.03);
  417. }
  418. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16_STR1_LARGE_GROUP) {
  419. NormalRNG rng(1);
  420. checker_conv_bias_f16(get_conv_bias_args({2, 3, 5}, 1, false, false, false),
  421. handle(), rng, "F16STRD1_LARGE_GROUP", 0.03);
  422. }
  423. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16_STR1_SMALL_GROUP) {
  424. NormalRNG rng(1);
  425. checker_conv_bias_f16(get_conv_bias_args({2, 3, 5}, 1, false, false, false),
  426. handle(), rng, "F16STRD1_SMALL_GROUP", 0.03);
  427. }
  428. #endif
  429. /**********************************algo 8816 direct************************/
  430. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT16_DIRECT_LARGE_GROUP) {
  431. checker_conv_bias_int8x8x16(
  432. get_conv_bias_args({2, 3, 5}, 1, false, true, true), handle(),
  433. "I8816DIRECT_LARGE_GROUP");
  434. }
  435. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT16_DIRECT_SMALL_GROUP) {
  436. checker_conv_bias_int8x8x16(
  437. get_conv_bias_args({2, 3, 5}, 1, false, true, true), handle(),
  438. "I8816DIRECT_SMALL_GROUP");
  439. }
  440. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT16_STRIDE2_LARGE_GROUP) {
  441. checker_conv_bias_int8x8x16(
  442. get_conv_bias_args({2, 3, 5}, 2, false, true, true), handle(),
  443. "I8816STRD2_LARGE_GROUP");
  444. }
  445. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT16_STRIDE2_SMALL_GROUP) {
  446. checker_conv_bias_int8x8x16(
  447. get_conv_bias_args({2, 3, 5}, 2, false, true, true), handle(),
  448. "I8816STRD2_SMALL_GROUP");
  449. }
  450. /**********************************algo 8-8-32 direct************************/
  451. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT32_STRIDE1_LARGE_GROUP) {
  452. checker_conv_bias_int8x8x32_multi(
  453. get_conv_bias_args({2, 3, 5, 7}, 1, false, true, true), handle(),
  454. "S8STRD1_LARGE_GROUP");
  455. }
  456. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT32_STRIDE1_SMALL_GROUP) {
  457. checker_conv_bias_int8x8x32_multi(
  458. get_conv_bias_args({2, 3, 5, 7}, 1, false, true, true), handle(),
  459. "S8STRD1_SMALL_GROUP");
  460. }
  461. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT32_STRIDE2_LARGE_GROUP) {
  462. checker_conv_bias_int8x8x32_multi(
  463. get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(),
  464. "S8STRD2_LARGE_GROUP");
  465. }
  466. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT32_STRIDE2_SMALL_GROUP) {
  467. checker_conv_bias_int8x8x32_multi(
  468. get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(),
  469. "S8STRD2_SMALL_GROUP");
  470. }
  471. TEST_F(ARM_COMMON_MULTI_THREADS,
  472. CONV_BIAS_INT8_INT8_INT32_CHANNEL_WISE_DIRECT1_NCHW44) {
  473. checker_conv_bias_int8x8x32_multi(
  474. get_nchw44_channel_wise_args({2, 3, 5}, 1, false, true, true),
  475. handle(), "S8_CHAN_WISE_STRD1_NCHW44");
  476. }
  477. TEST_F(ARM_COMMON_MULTI_THREADS,
  478. CONV_BIAS_INT8_INT8_INT32_CHANNEL_WISE_DIRECT2_NCHW44) {
  479. checker_conv_bias_int8x8x32_multi(
  480. get_nchw44_channel_wise_args({2, 3, 5}, 2, false, true, true),
  481. handle(), "S8_CHAN_WISE_STRD2_NCHW44");
  482. }
  483. /********************************qint8 direct******************************/
  484. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_LARGE_GROUP) {
  485. checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args(
  486. {2, 3, 5, 7}, 1, false, false, false),
  487. handle(), "S8STRD1_LARGE_GROUP");
  488. }
  489. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_SMALL_GROUP) {
  490. checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args(
  491. {2, 3, 5, 7}, 1, false, false, false),
  492. handle(), "S8STRD1_SMALL_GROUP");
  493. }
  494. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_LARGE_GROUP) {
  495. checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args(
  496. {2, 3, 5, 7}, 2, false, false, false),
  497. handle(), "S8STRD2_LARGE_GROUP");
  498. }
  499. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_SMALL_GROUP) {
  500. checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args(
  501. {2, 3, 5, 7}, 2, false, false, false),
  502. handle(), "S8STRD2_SMALL_GROUP");
  503. }
  504. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_NCHW44) {
  505. checker_conv_bias_qint8x8x8(
  506. get_nchw44_conv_bias_args({2, 3, 5, 7}, 1, false, false, false),
  507. handle(), "S8_NCHW44_DIRECT");
  508. }
  509. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_NCHW44_8832) {
  510. checker_conv_bias_qint8x8x32(
  511. get_nchw44_conv_bias_args({2, 3, 5, 7}, 1, false, false, true),
  512. handle(), "S8_NCHW44_DIRECT");
  513. }
  514. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_NCHW44_8832) {
  515. checker_conv_bias_qint8x8x32(
  516. get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, false, true),
  517. handle(), "S8_NCHW44_DIRECT");
  518. }
  519. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_NCHW44) {
  520. checker_conv_bias_qint8x8x8(
  521. get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, false, false),
  522. handle(), "S8_NCHW44_DIRECT");
  523. }
  524. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QS8_CHANNEL_WISE_DIRECT1_NCHW44) {
  525. checker_conv_bias_qint8x8x8(
  526. get_nchw44_channel_wise_args({2, 3, 5}, 1, false, false, true),
  527. handle(), "S8_CHAN_WISE_STRD1_NCHW44");
  528. }
  529. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QS8_CHANNEL_WISE_DIRECT2_NCHW44) {
  530. checker_conv_bias_qint8x8x8(
  531. get_nchw44_channel_wise_args({2, 3, 5}, 2, false, false, true),
  532. handle(), "S8_CHAN_WISE_STRD2_NCHW44");
  533. }
  534. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_NCHW_NCHW44_S1) {
  535. checker_conv_bias_qint8x8x8(
  536. get_nchw44_conv_bias_args({2, 3, 5, 7}, 1, false, false, false,
  537. true),
  538. handle(), "S8_CONV_NCHW_NCHW44");
  539. }
  540. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_NCHW_NCHW44_S2) {
  541. checker_conv_bias_qint8x8x8(
  542. get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, false, false,
  543. true),
  544. handle(), "S8_CONV_NCHW_NCHW44");
  545. }
  546. /*****************************quint8 direct****************************/
  547. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE1_LARGE_GROUP) {
  548. checker_conv_bias_quint8x8x8(get_int8_quint8_conv_bias_args(
  549. {2, 3, 5, 7}, 1, false, false, false),
  550. handle(), "QU8STRD1_LARGE_GROUP");
  551. }
  552. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE1_SMALL_GROUP) {
  553. checker_conv_bias_quint8x8x8(get_int8_quint8_conv_bias_args(
  554. {2, 3, 5, 7}, 1, false, false, false),
  555. handle(), "QU8STRD1_SMALL_GROUP");
  556. }
  557. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE2_LARGE_GROUP) {
  558. checker_conv_bias_quint8x8x8(get_int8_quint8_conv_bias_args(
  559. {2, 3, 5, 7}, 2, false, false, false),
  560. handle(), "QU8STRD2_LARGE_GROUP");
  561. }
  562. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE2_SMALL_GROUP) {
  563. checker_conv_bias_quint8x8x8(get_int8_quint8_conv_bias_args(
  564. {2, 3, 5, 7}, 2, false, false, false),
  565. handle(), "QU8STRD2_SMALL_GROUP");
  566. }
  567. /****************************dot qint8 direct*************************/
  568. #if __ARM_FEATURE_DOTPROD
  569. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_NCHW_NCHW44) {
  570. auto args = get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, false, false,
  571. true);
  572. for (auto&& arg : args) {
  573. arg.param.format = param::ConvBias::Format::NCHW44_DOT;
  574. }
  575. checker_conv_bias_qint8x8x8(args, handle(), "ARMDOTS8_NCHW_NCHW44");
  576. args = get_nchw44_conv_bias_args({2, 3, 5, 7}, 1, false, false, false,
  577. true);
  578. for (auto&& arg : args) {
  579. arg.param.format = param::ConvBias::Format::NCHW44_DOT;
  580. }
  581. checker_conv_bias_qint8x8x8(args, handle(), "ARMDOTS8_NCHW_NCHW44");
  582. }
  583. TEST_F(ARM_COMMON_MULTI_THREADS,
  584. CONV_BIAS_INT8_STRIDE1_WITHDOTPROD_LARGE_GROUP) {
  585. checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args(
  586. {2, 3, 5, 7}, 1, false, false, false),
  587. handle(), "ARMDOTS8STRD1_LARGE_GROUP");
  588. }
  589. TEST_F(ARM_COMMON_MULTI_THREADS,
  590. CONV_BIAS_INT8_STRIDE1_WITHDOTPROD_SMALL_GROUP) {
  591. checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args(
  592. {2, 3, 5, 7}, 1, false, false, false),
  593. handle(), "ARMDOTS8STRD1_SMALL_GROUP");
  594. }
  595. TEST_F(ARM_COMMON_MULTI_THREADS,
  596. CONV_BIAS_INT8_STRIDE2_WITHDOTPROD_LARGE_GROUP) {
  597. checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args(
  598. {2, 3, 5, 7}, 2, false, false, false),
  599. handle(), "ARMDOTS8STRD2_LARGE_GROUP");
  600. }
  601. TEST_F(ARM_COMMON_MULTI_THREADS,
  602. CONV_BIAS_INT8_STRIDE2_WITHDOTPROD_SMALL_GROUP) {
  603. checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args(
  604. {2, 3, 5, 7}, 2, false, false, false),
  605. handle(), "ARMDOTS8STRD2_SMALL_GROUP");
  606. }
  607. /****************************dot 8-8-32 direct*************************/
  608. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_I8832STRD1_WITHDOT_LARGE_GROUP) {
  609. checker_conv_bias_qint8x8x32(
  610. get_conv_bias_args({2, 3, 5, 7}, 1, false, true, true), handle(),
  611. "ARMDOTS8STRD1_LARGE_GROUP");
  612. }
  613. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_I8832STRD1_WITHDOT_SMALL_GROUP) {
  614. checker_conv_bias_qint8x8x32(
  615. get_conv_bias_args({2, 3, 5, 7}, 1, false, true, true), handle(),
  616. "ARMDOTS8STRD1_SMALL_GROUP");
  617. }
  618. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_I8832STRD2_WITHDOT_LARGE_GROUP) {
  619. checker_conv_bias_qint8x8x32(
  620. get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(),
  621. "ARMDOTS8STRD2_LARGE_GROUP");
  622. }
  623. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_I8832STRD2_WITHDOT_SMALL_GROUP) {
  624. checker_conv_bias_qint8x8x32(
  625. get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(),
  626. "ARMDOTS8STRD2_SMALL_GROUP");
  627. }
  628. /******************************dot quint8*****************************/
  629. TEST_F(ARM_COMMON_MULTI_THREADS,
  630. CONV_BIAS_QUINT8_STRIDE1_WITHDOTPROD_LARGE_GROUP) {
  631. checker_conv_bias_quint8x8x8(get_int8_quint8_conv_bias_args(
  632. {2, 3, 5, 7}, 1, false, false, false),
  633. handle(), "ARMDOTU8STRD1_LARGE_GROUP");
  634. }
  635. TEST_F(ARM_COMMON_MULTI_THREADS,
  636. CONV_BIAS_QUINT8_STRIDE1_WITHDOTPROD_SMALL_GROUP) {
  637. checker_conv_bias_quint8x8x8(get_int8_quint8_conv_bias_args(
  638. {2, 3, 5, 7}, 1, false, false, false),
  639. handle(), "ARMDOTU8STRD1_SMALL_GROUP");
  640. }
  641. TEST_F(ARM_COMMON_MULTI_THREADS,
  642. CONV_BIAS_QUINT8_STRIDE2_WITHDOTPROD_LARGE_GROUP) {
  643. checker_conv_bias_quint8x8x8(
  644. get_int8_quint8_conv_bias_args({2, 5, 7}, 2, false, false, false),
  645. handle(), "ARMDOTU8STRD2_LARGE_GROUP");
  646. }
  647. TEST_F(ARM_COMMON_MULTI_THREADS,
  648. CONV_BIAS_QUINT8_STRIDE2_WITHDOTPROD_SMALL_GROUP) {
  649. checker_conv_bias_quint8x8x8(
  650. get_int8_quint8_conv_bias_args({2, 5, 7}, 2, false, false, false),
  651. handle(), "ARMDOTU8STRD2_SMALL_GROUP");
  652. }
  653. /******************************dot quint8x8x32***********************/
  654. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_QUINT8_DIRECT_STRIDE1_LARGE_GROUP) {
  655. checker_conv_bias_quint8x8x32(
  656. get_conv_bias_args({2, 3, 5, 7}, 1, false, true, true), handle(),
  657. "ARMDOTU8STRD1_LARGE_GROUP");
  658. }
  659. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_QUINT8_DIRECT_STRIDE1_SMALL_GROUP) {
  660. checker_conv_bias_quint8x8x32(
  661. get_conv_bias_args({2, 3, 5, 7}, 1, false, true, true), handle(),
  662. "ARMDOTU8STRD1_SMALL_GROUP");
  663. }
  664. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_QUINT8_DIRECT_STRIDE2_LARGE_GROUP) {
  665. checker_conv_bias_quint8x8x32(
  666. get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(),
  667. "ARMDOTU8STRD2_LARGE_GROUP");
  668. }
  669. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_QUINT8_DIRECT_STRIDE2_SMALL_GROUP) {
  670. checker_conv_bias_quint8x8x32(
  671. get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(),
  672. "ARMDOTU8STRD2_SMALL_GROUP");
  673. }
  674. /******************************dot int8x8x8 nchw44 ***********************/
  675. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_DIRECT_DOT_NCHW44_S1_Q8x8x8) {
  676. using namespace conv_bias;
  677. std::vector<TestArg> args = get_nchw44_conv_bias_args({2, 3, 5, 7}, 1);
  678. for (auto&& arg : args)
  679. arg.param.format = param::ConvBias::Format::NCHW44_DOT;
  680. checker_conv_bias_qint8x8x8(args, handle(), "ARMDOTS8DIRECT_NCHW44");
  681. }
  682. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_DIRECT_DOT_NCHW44_S1_Q8x8x32) {
  683. using namespace conv_bias;
  684. std::vector<TestArg> args =
  685. get_nchw44_conv_bias_args({2, 3, 5, 7}, 1, false, true, true);
  686. for (auto&& arg : args)
  687. arg.param.format = param::ConvBias::Format::NCHW44_DOT;
  688. checker_conv_bias_qint8x8x32(args, handle(), "ARMDOTS8DIRECT_NCHW44");
  689. }
  690. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_DIRECT_DOT_NCHW44_S1_8x8x32) {
  691. using namespace conv_bias;
  692. std::vector<TestArg> args =
  693. get_nchw44_conv_bias_args({2, 3, 5, 7}, 1, false, true, true);
  694. for (auto&& arg : args)
  695. arg.param.format = param::ConvBias::Format::NCHW44_DOT;
  696. checker_conv_bias_int8x8x32_multi(args, handle(), "ARMDOTS8DIRECT_NCHW44");
  697. }
  698. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_DIRECT_DOT_NCHW44_S2_Q8x8x8) {
  699. using namespace conv_bias;
  700. //! test qint8x8x8
  701. std::vector<TestArg> args = get_nchw44_conv_bias_args({2, 3, 5, 7}, 2);
  702. for (auto&& arg : args)
  703. arg.param.format = param::ConvBias::Format::NCHW44_DOT;
  704. checker_conv_bias_qint8x8x8(args, handle(), "ARMDOTS8DIRECT_NCHW44");
  705. }
  706. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_DIRECT_DOT_NCHW44_S2_Q8x8x32) {
  707. using namespace conv_bias;
  708. //! test qint8x8x8
  709. std::vector<TestArg> args =
  710. get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, true, true);
  711. for (auto&& arg : args)
  712. arg.param.format = param::ConvBias::Format::NCHW44_DOT;
  713. checker_conv_bias_qint8x8x32(args, handle(), "ARMDOTS8DIRECT_NCHW44");
  714. }
  715. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_DIRECT_DOT_NCHW44_S2_8x8x32) {
  716. using namespace conv_bias;
  717. //! test qint8x8x8
  718. std::vector<TestArg> args =
  719. get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, true, true);
  720. for (auto&& arg : args)
  721. arg.param.format = param::ConvBias::Format::NCHW44_DOT;
  722. checker_conv_bias_int8x8x32_multi(args, handle(), "ARMDOTS8DIRECT_NCHW44");
  723. }
  724. #endif
  725. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F23_4) {
  726. using namespace conv_bias;
  727. std::vector<TestArg> args = get_winograd_mk_packed_args();
  728. Checker<ConvBiasForward> checker(handle());
  729. check_winograd("4:2:32", checker, args, param::MatrixMul::Format::MK4);
  730. }
  731. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F23_4_NCHW44) {
  732. using namespace conv_bias;
  733. std::vector<TestArg> args = get_nchw44_conv_bias_args({3}, 1);
  734. Checker<ConvBiasForward> checker(handle());
  735. check_winograd("4:2:32", checker, args, param::MatrixMul::Format::MK4,
  736. param::ConvBias::Format::NCHW44);
  737. }
  738. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F63) {
  739. using namespace conv_bias;
  740. std::vector<TestArg> args = get_winograd_args(3);
  741. Checker<ConvBiasForward> checker(handle());
  742. check_winograd("1:6:32", checker, args);
  743. }
  744. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F63_4) {
  745. using namespace conv_bias;
  746. std::vector<TestArg> args = get_winograd_mk_packed_args();
  747. Checker<ConvBiasForward> checker(handle());
  748. check_winograd("4:6:16", checker, args, param::MatrixMul::Format::MK4);
  749. }
  750. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F63_4_NCHW44) {
  751. using namespace conv_bias;
  752. std::vector<TestArg> args = get_nchw44_conv_bias_args({3}, 1);
  753. Checker<ConvBiasForward> checker(handle());
  754. check_winograd("4:6:16", checker, args, param::MatrixMul::Format::MK4,
  755. param::ConvBias::Format::NCHW44);
  756. }
  757. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F54) {
  758. using namespace conv_bias;
  759. std::vector<TestArg> args = get_winograd_args(4);
  760. Checker<ConvBiasForward> checker(handle());
  761. check_winograd("1:5:32", checker, args);
  762. }
  763. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F45) {
  764. using namespace conv_bias;
  765. std::vector<TestArg> args = get_winograd_args(5);
  766. Checker<ConvBiasForward> checker(handle());
  767. check_winograd("1:4:32", checker, args);
  768. }
  769. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD) {
  770. using namespace conv_bias;
  771. std::vector<TestArg> args = get_winograd_args(3);
  772. Checker<ConvBiasForward> checker(handle());
  773. auto extra_impl = [](const TensorNDArray& tensors, uint32_t m,
  774. param::ConvBias param, Handle* handle) {
  775. megdnn_assert(param.format == param::ConvBias::Format::NCHW);
  776. auto winograd_preprocess_opr =
  777. handle->create_operator<WinogradFilterPreprocess>();
  778. winograd_preprocess_opr->param().output_block_size = m;
  779. TensorLayout filter_transform_layout;
  780. winograd_preprocess_opr->deduce_layout(tensors[1].layout,
  781. filter_transform_layout);
  782. size_t winograd_preprocess_workspace_in_bytes =
  783. winograd_preprocess_opr->get_workspace_in_bytes(
  784. tensors[1].layout, filter_transform_layout);
  785. auto conv_bias_opr = handle->create_operator<ConvBias>();
  786. conv_bias_opr->param() = param;
  787. conv_bias_opr->param().format = param::ConvBias::Format::NCHW_WINOGRAD;
  788. conv_bias_opr->param().output_block_size = m;
  789. size_t conv_bias_workspace_in_bytes =
  790. conv_bias_opr->get_workspace_in_bytes(
  791. tensors[0].layout, filter_transform_layout,
  792. tensors[2].layout, tensors[3].layout, tensors[4].layout,
  793. nullptr);
  794. WorkspaceBundle wb(nullptr, {filter_transform_layout.span().dist_byte(),
  795. conv_bias_workspace_in_bytes,
  796. winograd_preprocess_workspace_in_bytes});
  797. wb.set(malloc(wb.total_size_in_bytes()));
  798. TensorND filter_transform_tensor(wb.get(0),
  799. std::move(filter_transform_layout));
  800. winograd_preprocess_opr->exec(tensors[1], filter_transform_tensor,
  801. wb.get_workspace(2));
  802. conv_bias_opr->exec(tensors[0], filter_transform_tensor, tensors[2],
  803. tensors[3], tensors[4], nullptr,
  804. wb.get_workspace(1));
  805. free(wb.ptr());
  806. };
  807. auto run = [&checker, &extra_impl](
  808. Handle* handle, const std::vector<TestArg>& args,
  809. const std::vector<size_t>& out_size, DType A_dtype,
  810. DType B_dtype, DType C_dtype, DType D_dtype,
  811. const float eps) {
  812. for (auto&& arg : args) {
  813. for (uint32_t m : out_size) {
  814. checker.set_extra_opr_impl(std::bind(extra_impl,
  815. std::placeholders::_1, m,
  816. arg.param, handle));
  817. checker.set_dtype(0, A_dtype)
  818. .set_dtype(1, B_dtype)
  819. .set_dtype(2, C_dtype)
  820. .set_dtype(4, D_dtype)
  821. .set_epsilon(eps)
  822. .set_param(arg.param)
  823. .execs({arg.src, arg.filter, arg.bias, {}, {}});
  824. }
  825. }
  826. };
  827. run(handle(), args, {6}, dtype::Float32(), dtype::Float32(),
  828. dtype::Float32(), dtype::Float32(), 1e-3f);
  829. #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  830. Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00);
  831. checker.set_rng(0, rng).set_rng(1, rng).set_rng(2, rng);
  832. run(handle(), args, {6}, dtype::Float16(), dtype::Float16(),
  833. dtype::Float16(), dtype::Float16(), 0.35f);
  834. #endif
  835. }
  836. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_PREPROCESS_NCHW44) {
  837. using namespace conv_bias;
  838. std::vector<TestArg> nchw44_args = get_nchw44_conv_bias_args({3}, 1);
  839. Checker<ConvBiasForward> checker(handle());
  840. auto extra_impl = [](const TensorNDArray& tensors, uint32_t m,
  841. param::ConvBias param, Handle* handle) {
  842. megdnn_assert(param.format == param::ConvBias::Format::NCHW44);
  843. auto winograd_preprocess_opr =
  844. handle->create_operator<WinogradFilterPreprocess>();
  845. winograd_preprocess_opr->param().output_block_size = m;
  846. winograd_preprocess_opr->param().format = param::MatrixMul::Format::MK4;
  847. TensorLayout filter_transform_layout;
  848. winograd_preprocess_opr->deduce_layout(tensors[1].layout,
  849. filter_transform_layout);
  850. size_t winograd_preprocess_workspace_in_bytes =
  851. winograd_preprocess_opr->get_workspace_in_bytes(
  852. tensors[1].layout, filter_transform_layout);
  853. auto conv_bias_opr = handle->create_operator<ConvBias>();
  854. conv_bias_opr->param() = param;
  855. conv_bias_opr->param().format =
  856. param::ConvBias::Format::NCHW44_WINOGRAD;
  857. conv_bias_opr->param().output_block_size = m;
  858. size_t conv_bias_workspace_in_bytes =
  859. conv_bias_opr->get_workspace_in_bytes(
  860. tensors[0].layout, filter_transform_layout,
  861. tensors[2].layout, tensors[3].layout, tensors[4].layout,
  862. nullptr);
  863. WorkspaceBundle wb(nullptr, {filter_transform_layout.span().dist_byte(),
  864. conv_bias_workspace_in_bytes,
  865. winograd_preprocess_workspace_in_bytes});
  866. wb.set(malloc(wb.total_size_in_bytes()));
  867. TensorND filter_transform_tensor(wb.get(0),
  868. std::move(filter_transform_layout));
  869. winograd_preprocess_opr->exec(tensors[1], filter_transform_tensor,
  870. wb.get_workspace(2));
  871. conv_bias_opr->exec(tensors[0], filter_transform_tensor, tensors[2],
  872. tensors[3], tensors[4], nullptr,
  873. wb.get_workspace(1));
  874. free(wb.ptr());
  875. };
  876. auto run = [&checker, &extra_impl](
  877. Handle* handle, const std::vector<TestArg>& args,
  878. const std::vector<size_t>& out_size, DType A_dtype,
  879. DType B_dtype, DType C_dtype, DType D_dtype,
  880. const float eps) {
  881. for (auto&& arg : args) {
  882. for (uint32_t m : out_size) {
  883. checker.set_extra_opr_impl(std::bind(extra_impl,
  884. std::placeholders::_1, m,
  885. arg.param, handle));
  886. checker.set_dtype(0, A_dtype)
  887. .set_dtype(1, B_dtype)
  888. .set_dtype(2, C_dtype)
  889. .set_dtype(4, D_dtype)
  890. .set_epsilon(eps)
  891. .set_param(arg.param)
  892. .execs({arg.src, arg.filter, arg.bias, {}, {}});
  893. }
  894. }
  895. };
  896. run(handle(), nchw44_args, {2, 6}, dtype::Float32(), dtype::Float32(),
  897. dtype::Float32(), dtype::Float32(), 1e-3f);
  898. }
  899. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_MK_PACKED_F32_1) {
  900. using namespace conv_bias;
  901. Checker<ConvBiasForward> checker(handle());
  902. auto run = [&checker](Handle* handle, const std::vector<TestArg>& args,
  903. const std::vector<size_t>& out_size, DType A_dtype,
  904. DType B_dtype, DType C_dtype, DType D_dtype,
  905. param::MatrixMul::Format format, float eps) {
  906. for (auto&& arg : args) {
  907. for (uint32_t m : out_size) {
  908. checker.set_extra_opr_impl(std::bind(
  909. winograd_algo_extra_impl, std::placeholders::_1, m,
  910. arg.param, handle, format));
  911. checker.set_dtype(0, A_dtype)
  912. .set_dtype(1, B_dtype)
  913. .set_dtype(2, C_dtype)
  914. .set_dtype(4, D_dtype)
  915. .set_epsilon(eps)
  916. .set_param(arg.param)
  917. .execs({arg.src, arg.filter, arg.bias, {}, {}});
  918. }
  919. }
  920. };
  921. std::vector<TestArg> args = get_winograd_mk_packed_args(8);
  922. std::vector<TestArg> args_first_half(args.begin(),
  923. args.begin() + args.size() / 2);
  924. run(handle(), args_first_half, {2, 6}, dtype::Float32{}, dtype::Float32{},
  925. dtype::Float32{}, dtype::Float32{}, param::MatrixMul::Format::MK4,
  926. 1e-3f);
  927. }
  928. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_MK_PACKED_F32_2) {
  929. using namespace conv_bias;
  930. Checker<ConvBiasForward> checker(handle());
  931. auto run = [&checker](Handle* handle, const std::vector<TestArg>& args,
  932. const std::vector<size_t>& out_size, DType A_dtype,
  933. DType B_dtype, DType C_dtype, DType D_dtype,
  934. param::MatrixMul::Format format, float eps) {
  935. for (auto&& arg : args) {
  936. for (uint32_t m : out_size) {
  937. checker.set_extra_opr_impl(std::bind(
  938. winograd_algo_extra_impl, std::placeholders::_1, m,
  939. arg.param, handle, format));
  940. checker.set_dtype(0, A_dtype)
  941. .set_dtype(1, B_dtype)
  942. .set_dtype(2, C_dtype)
  943. .set_dtype(4, D_dtype)
  944. .set_epsilon(eps)
  945. .set_param(arg.param)
  946. .execs({arg.src, arg.filter, arg.bias, {}, {}});
  947. }
  948. }
  949. };
  950. std::vector<TestArg> args = get_winograd_mk_packed_args(8);
  951. std::vector<TestArg> args_second_half(args.begin() + args.size() / 2,
  952. args.end());
  953. run(handle(), args_second_half, {2, 6}, dtype::Float32{}, dtype::Float32{},
  954. dtype::Float32{}, dtype::Float32{}, param::MatrixMul::Format::MK4,
  955. 1e-3f);
  956. }
  957. #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  958. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_MK_PACKED_F16) {
  959. using namespace conv_bias;
  960. Checker<ConvBiasForward> checker(handle());
  961. auto run = [&checker](Handle* handle, const std::vector<TestArg>& args,
  962. const std::vector<size_t>& out_size, DType A_dtype,
  963. DType B_dtype, DType C_dtype, DType D_dtype,
  964. param::MatrixMul::Format format, float eps) {
  965. for (auto&& arg : args) {
  966. for (uint32_t m : out_size) {
  967. checker.set_extra_opr_impl(std::bind(
  968. winograd_algo_extra_impl, std::placeholders::_1, m,
  969. arg.param, handle, format));
  970. checker.set_dtype(0, A_dtype)
  971. .set_dtype(1, B_dtype)
  972. .set_dtype(2, C_dtype)
  973. .set_dtype(4, D_dtype)
  974. .set_epsilon(eps)
  975. .set_param(arg.param)
  976. .execs({arg.src, arg.filter, arg.bias, {}, {}});
  977. }
  978. }
  979. };
  980. std::vector<TestArg> args = get_winograd_mk_packed_args(8);
  981. Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00);
  982. checker.set_rng(0, rng).set_rng(1, rng).set_rng(2, rng);
  983. run(handle(), args, {2}, dtype::Float16{}, dtype::Float16{},
  984. dtype::Float16{}, dtype::Float16{}, param::MatrixMul::Format::MK8,
  985. 0.25);
  986. }
  987. #endif
  988. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_MK_PACKED_INT8) {
  989. using namespace conv_bias;
  990. Checker<ConvBiasForward> checker(handle());
  991. auto run = [&checker](Handle* handle, const std::vector<TestArg>& args,
  992. const std::vector<size_t>& out_size, DType A_dtype,
  993. DType B_dtype, DType C_dtype, DType D_dtype,
  994. param::MatrixMul::Format format, float eps) {
  995. for (auto&& arg : args) {
  996. for (uint32_t m : out_size) {
  997. checker.set_extra_opr_impl(std::bind(
  998. winograd_algo_extra_impl, std::placeholders::_1, m,
  999. arg.param, handle, format));
  1000. checker.set_dtype(0, A_dtype)
  1001. .set_dtype(1, B_dtype)
  1002. .set_dtype(2, C_dtype)
  1003. .set_dtype(4, D_dtype)
  1004. .set_epsilon(eps)
  1005. .set_param(arg.param)
  1006. .execs({arg.src, arg.filter, arg.bias, {}, {}});
  1007. }
  1008. }
  1009. };
  1010. #if MEGDNN_AARCH64
  1011. const char* matmul_name = "AARCH64_INT16X16X32_MK8_8X8";
  1012. #else
  1013. const char* matmul_name = "ARMV7_INT16X16X32_MK8_4X8";
  1014. #endif
  1015. checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>(
  1016. ssprintf("WINOGRAD:%s:8:2:32", matmul_name).c_str()));
  1017. std::vector<TestArg> quantized_args =
  1018. get_quantized_winograd_mk_packed_args(8);
  1019. UniformIntRNG int_rng{-50, 50};
  1020. checker.set_rng(0, &int_rng).set_rng(1, &int_rng).set_rng(2, &int_rng);
  1021. run(handle(), quantized_args, {2}, dtype::QuantizedS8(2.5f),
  1022. dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f),
  1023. dtype::QuantizedS8(60.25f), param::MatrixMul::Format::MK8, 1e-3);
  1024. }
  1025. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8) {
  1026. using namespace conv_bias;
  1027. Checker<ConvBiasForward> checker(handle());
  1028. auto run = [&checker](Handle* handle, const std::vector<TestArg>& args,
  1029. const std::vector<size_t>& out_size, DType A_dtype,
  1030. DType B_dtype, DType C_dtype, DType D_dtype,
  1031. param::MatrixMul::Format format, float eps) {
  1032. for (auto&& arg : args) {
  1033. for (uint32_t m : out_size) {
  1034. checker.set_extra_opr_impl(std::bind(
  1035. winograd_algo_extra_impl, std::placeholders::_1, m,
  1036. arg.param, handle, format));
  1037. checker.set_dtype(0, A_dtype)
  1038. .set_dtype(1, B_dtype)
  1039. .set_dtype(2, C_dtype)
  1040. .set_dtype(4, D_dtype)
  1041. .set_epsilon(eps)
  1042. .set_param(arg.param)
  1043. .execs({arg.src, arg.filter, arg.bias, {}, {}});
  1044. }
  1045. }
  1046. };
  1047. #if MEGDNN_AARCH64
  1048. const char* matmul_name = "AARCH64_INT16X16X32_MK8_8X8";
  1049. #else
  1050. const char* matmul_name = "ARMV7_INT16X16X32_MK8_4X8";
  1051. #endif
  1052. checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>(
  1053. ssprintf("WINOGRAD_NCHW44:%s:8:2:32", matmul_name).c_str()));
  1054. std::vector<TestArg> quantized_args = get_int8_nchw44_args(3, 4);
  1055. UniformIntRNG int_rng{-50, 50};
  1056. checker.set_rng(0, &int_rng).set_rng(1, &int_rng).set_rng(2, &int_rng);
  1057. run(handle(), quantized_args, {2}, dtype::QuantizedS8(2.5f),
  1058. dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f),
  1059. dtype::QuantizedS8(60.25f), param::MatrixMul::Format::MK8, 1e-3);
  1060. }
  1061. TEST_F(ARM_COMMON_MULTI_THREADS,
  1062. CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8_GROUPMODE) {
  1063. using namespace conv_bias;
  1064. Checker<ConvBiasForward> checker(handle());
  1065. auto run = [&checker](Handle* handle, const std::vector<TestArg>& args,
  1066. const std::vector<size_t>& out_size, DType A_dtype,
  1067. DType B_dtype, DType C_dtype, DType D_dtype,
  1068. param::MatrixMul::Format format, float eps) {
  1069. for (auto&& arg : args) {
  1070. for (uint32_t m : out_size) {
  1071. checker.set_extra_opr_impl(std::bind(
  1072. winograd_algo_extra_impl, std::placeholders::_1, m,
  1073. arg.param, handle, format));
  1074. checker.set_dtype(0, A_dtype)
  1075. .set_dtype(1, B_dtype)
  1076. .set_dtype(2, C_dtype)
  1077. .set_dtype(4, D_dtype)
  1078. .set_epsilon(eps)
  1079. .set_param(arg.param)
  1080. .execs({arg.src, arg.filter, arg.bias, {}, {}});
  1081. }
  1082. }
  1083. };
  1084. #if MEGDNN_AARCH64
  1085. const char* matmul_name = "AARCH64_INT16X16X32_MK8_8X8";
  1086. #else
  1087. const char* matmul_name = "ARMV7_INT16X16X32_MK8_4X8";
  1088. #endif
  1089. checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>(
  1090. ssprintf("WINOGRAD_NCHW44:%s:8:2:32", matmul_name).c_str()));
  1091. std::vector<TestArg> quantized_args =
  1092. get_int8_nchw44_args(3, 4, false, true);
  1093. UniformIntRNG int_rng{-50, 50};
  1094. checker.set_rng(0, &int_rng).set_rng(1, &int_rng).set_rng(2, &int_rng);
  1095. run(handle(), quantized_args, {2}, dtype::QuantizedS8(2.5f),
  1096. dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f),
  1097. dtype::QuantizedS8(60.25f), param::MatrixMul::Format::MK8, 1e-3);
  1098. }
  1099. TEST_F(ARM_COMMON_MULTI_THREADS,
  1100. CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8_COMP_F32) {
  1101. using namespace conv_bias;
  1102. Checker<ConvBiasForward> checker(handle());
  1103. auto run = [&checker](Handle* handle, const std::vector<TestArg>& args,
  1104. const std::vector<size_t>& out_size, DType A_dtype,
  1105. DType B_dtype, DType C_dtype, DType D_dtype,
  1106. param::MatrixMul::Format format, float eps) {
  1107. for (auto&& arg : args) {
  1108. for (uint32_t m : out_size) {
  1109. checker.set_extra_opr_impl(std::bind(
  1110. winograd_algo_extra_impl, std::placeholders::_1, m,
  1111. arg.param, handle, format));
  1112. checker.set_dtype(0, A_dtype)
  1113. .set_dtype(1, B_dtype)
  1114. .set_dtype(2, C_dtype)
  1115. .set_dtype(4, D_dtype)
  1116. .set_epsilon(eps)
  1117. .set_param(arg.param)
  1118. .execs({arg.src, arg.filter, arg.bias, {}, {}});
  1119. }
  1120. }
  1121. };
  1122. float epsilon = 0.001;
  1123. #if MEGDNN_AARCH64
  1124. const char* matmul_name = "AARCH64_F32_MK4_4x16";
  1125. #else
  1126. const char* matmul_name = "ARMV7_F32_MK4_4x8";
  1127. #endif
  1128. checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>(
  1129. ssprintf("WINOGRAD_NCHW44:%s:4:2:32", matmul_name).c_str()));
  1130. std::vector<TestArg> quantized_args = get_int8_nchw44_args(3, 4, true);
  1131. UniformIntRNG int_rng{-50, 50};
  1132. checker.set_rng(0, &int_rng).set_rng(1, &int_rng).set_rng(2, &int_rng);
  1133. run(handle(), quantized_args, {2}, dtype::QuantizedS8(0.41113496f),
  1134. dtype::QuantizedS8(0.01887994f),
  1135. dtype::QuantizedS32(0.41113496f * 0.01887994f),
  1136. dtype::QuantizedS8(0.49550694f), param::MatrixMul::Format::MK4,
  1137. epsilon);
  1138. }
  1139. TEST_F(ARM_COMMON_MULTI_THREADS,
  1140. CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8_COMP_F32_GROUPMODE) {
  1141. using namespace conv_bias;
  1142. Checker<ConvBiasForward> checker(handle());
  1143. auto run = [&checker](Handle* handle, const std::vector<TestArg>& args,
  1144. const std::vector<size_t>& out_size, DType A_dtype,
  1145. DType B_dtype, DType C_dtype, DType D_dtype,
  1146. param::MatrixMul::Format format, float eps) {
  1147. for (auto&& arg : args) {
  1148. for (uint32_t m : out_size) {
  1149. checker.set_extra_opr_impl(std::bind(
  1150. winograd_algo_extra_impl, std::placeholders::_1, m,
  1151. arg.param, handle, format));
  1152. checker.set_dtype(0, A_dtype)
  1153. .set_dtype(1, B_dtype)
  1154. .set_dtype(2, C_dtype)
  1155. .set_dtype(4, D_dtype)
  1156. .set_epsilon(eps)
  1157. .set_param(arg.param)
  1158. .execs({arg.src, arg.filter, arg.bias, {}, {}});
  1159. }
  1160. }
  1161. };
  1162. float epsilon = 0.001;
  1163. #if MEGDNN_AARCH64
  1164. const char* matmul_name = "AARCH64_F32_MK4_4x16";
  1165. #else
  1166. const char* matmul_name = "ARMV7_F32_MK4_4x8";
  1167. #endif
  1168. checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>(
  1169. ssprintf("WINOGRAD_NCHW44:%s:4:2:32", matmul_name).c_str()));
  1170. std::vector<TestArg> quantized_args =
  1171. get_int8_nchw44_args(3, 4, true, true);
  1172. UniformIntRNG int_rng{-50, 50};
  1173. checker.set_rng(0, &int_rng).set_rng(1, &int_rng).set_rng(2, &int_rng);
  1174. run(handle(), quantized_args, {2}, dtype::QuantizedS8(0.41113496f),
  1175. dtype::QuantizedS8(0.01887994f),
  1176. dtype::QuantizedS32(0.41113496f * 0.01887994f),
  1177. dtype::QuantizedS8(0.49550694f), param::MatrixMul::Format::MK4,
  1178. epsilon);
  1179. }
  #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  //! In the tests below, each Float16PeriodicalRNG is scope-owned: it is no
  //! longer heap-allocated and leaked. It is declared before the Checker so it
  //! outlives it.
  TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_F23) {
      using namespace conv_bias;
      std::vector<TestArg> args = get_winograd_mk_packed_args();
      Checker<ConvBiasForward> checker(handle());
      //! no custom RNG for this case
      check_winograd_fp16("1:2:32", checker, args, nullptr, 0.08);
  }
  TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_F45_1) {
      using namespace conv_bias;
      std::vector<TestArg> args = get_winograd_args(5);
      std::vector<TestArg> args_head_half(args.begin(),
                                          args.begin() + args.size() / 2);
      //! fp16 range -1.0 ~ 1.0
      Float16PeriodicalRNG rng(0x3c00);
      Checker<ConvBiasForward> checker(handle());
      check_winograd_fp16("1:4:32", checker, args_head_half, &rng, 0.25);
  }
  TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_F45_2) {
      using namespace conv_bias;
      std::vector<TestArg> args = get_winograd_args(5);
      std::vector<TestArg> args_back_half(args.begin() + args.size() / 2,
                                          args.end());
      //! fp16 range -1.0 ~ 1.0
      Float16PeriodicalRNG rng(0x3c00);
      Checker<ConvBiasForward> checker(handle());
      check_winograd_fp16("1:4:32", checker, args_back_half, &rng, 0.25);
  }
  //! FIXME: This test may fail when run together with the other
  //! `ARM_COMMON.CONV_BIAS_WINOGRAD*` tests, but it passes when run on its own.
  TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_F63) {
      using namespace conv_bias;
      std::vector<TestArg> args = get_winograd_args(3);
      //! fp16 range -1.0 ~ 1.0
      Float16PeriodicalRNG rng(0x3c00);
      Checker<ConvBiasForward> checker(handle());
      check_winograd_fp16("1:6:32", checker, args, &rng, 0.3);
  }
  TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_8x8_1) {
      using namespace conv_bias;
      std::vector<TestArg> args = get_winograd_mk_packed_args(8);
      std::vector<TestArg> args_head_half(args.begin(),
                                          args.begin() + args.size() / 2);
      Float16PeriodicalRNG rng(0x3c00);
      Checker<ConvBiasForward> checker(handle());
      check_winograd_fp16("8:2:32", checker, args_head_half, &rng, 0.25,
                          param::MatrixMul::Format::MK8);
  }
  TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_8x8_2) {
      using namespace conv_bias;
      std::vector<TestArg> args = get_winograd_mk_packed_args(8);
      std::vector<TestArg> args_back_half(args.begin() + args.size() / 2,
                                          args.end());
      Float16PeriodicalRNG rng(0x3c00);
      Checker<ConvBiasForward> checker(handle());
      check_winograd_fp16("8:2:32", checker, args_back_half, &rng, 0.25,
                          param::MatrixMul::Format::MK8);
  }
  #endif
  1238. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_INT8_8X8) {
  1239. using namespace conv_bias;
  1240. std::vector<TestArg> args = get_quantized_winograd_mk_packed_args(8);
  1241. Checker<ConvBiasForward> checker(handle());
  1242. UniformIntRNG rng{-50, 50};
  1243. checker.set_dtype(0, dtype::QuantizedS8(2.5f))
  1244. .set_dtype(1, dtype::QuantizedS8(2.5f))
  1245. .set_dtype(2, dtype::QuantizedS32(6.25f))
  1246. .set_dtype(4, dtype::QuantizedS8(60.25f))
  1247. .set_rng(0, &rng)
  1248. .set_rng(1, &rng)
  1249. .set_rng(2, &rng);
  1250. check_winograd("8:2:32", checker, args, param::MatrixMul::Format::MK8);
  1251. }
  1252. void checker_conv_bias(std::vector<conv_bias::TestArg> args, Handle* handle,
  1253. RNG* rng, float epsilon, DType type0, DType type1,
  1254. DType type2, DType type3, const char* algo_name) {
  1255. using namespace conv_bias;
  1256. Checker<ConvBias> checker(handle);
  1257. checker.set_before_exec_callback(
  1258. conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name));
  1259. checker.set_dtype(0, type0);
  1260. checker.set_dtype(1, type1);
  1261. checker.set_dtype(2, type2);
  1262. checker.set_dtype(4, type3);
  1263. checker.set_epsilon(epsilon);
  1264. if (NULL != rng) {
  1265. checker.set_rng(0, rng).set_rng(1, rng).set_rng(2, rng).set_rng(3, rng);
  1266. }
  1267. for (auto&& arg : args) {
  1268. checker.set_param(arg.param).execs(
  1269. {arg.src, arg.filter, arg.bias, {}, {}});
  1270. }
  1271. }
  1272. // clang-format off
  1273. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COL_FP32_STRIDE2) {
  1274. #define cb(name) \
  1275. check_conv_bias( \
  1276. get_conv_bias_args({1, 2, 3, 4, 5, 6, 7}, 2, false, false, false), \
  1277. handle(), name);
  1278. #if MEGDNN_AARCH64
  1279. cb("IM2COLMATMUL:AARCH64_F32K8X12X1")
  1280. cb("IM2COLMATMUL:AARCH64_F32K4X16X1")
  1281. cb("IM2COLMATMUL:FB_F32_K8X12X1")
  1282. #elif MEGDNN_ARMV7
  1283. cb("IM2COLMATMUL:ARMV7_F32")
  1284. #endif
  1285. #undef cb
  1286. }
  1287. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COL_FP32_STRIDE1) {
  1288. #define cb(name) \
  1289. check_conv_bias( \
  1290. get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, false, false), \
  1291. handle(), name);
  1292. #if MEGDNN_AARCH64
  1293. cb("IM2COLMATMUL:AARCH64_F32K8X12X1")
  1294. cb("IM2COLMATMUL:AARCH64_F32K4X16X1")
  1295. cb("IM2COLMATMUL:FB_F32_K8X12X1")
  1296. #elif MEGDNN_ARMV7
  1297. cb("IM2COLMATMUL:ARMV7_F32")
  1298. cb("IM2COLMATMUL:FB_F32_K8X12X1")
  1299. #endif
  1300. #undef cb
  1301. }
  1302. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM) {
  1303. UniformIntRNG rng{-50, 50};
  1304. #define cb(name) \
  1305. checker_conv_bias(get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, false, \
  1306. false, true, true), \
  1307. handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \
  1308. dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), \
  1309. dtype::QuantizedS8(60.25f), name); \
  1310. checker_conv_bias( \
  1311. get_conv_bias_args({1}, 2, false, false, false, true, true), \
  1312. handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \
  1313. dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), \
  1314. dtype::QuantizedS8(60.25f), name);
  1315. float epsilon = 0.001;
  1316. #if MEGDNN_AARCH64
  1317. #if __ARM_FEATURE_DOTPROD
  1318. cb("IM2COLMATMUL:AARCH64_INT8X8X32_K8X12X4_DOTPROD");
  1319. #else
  1320. cb("IM2COLMATMUL:AARCH64_INT8X8X32_K8X8X8");
  1321. cb("IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16");
  1322. #endif
  1323. #elif MEGDNN_ARMV7
  1324. epsilon = 1;
  1325. cb("IM2COLMATMUL:ARMV7_INT8X8X32_K4X8X8");
  1326. #endif
  1327. #undef cb
  1328. }
  1329. #if __ARM_FEATURE_DOTPROD
  1330. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM_MK4_DOT) {
  1331. UniformIntRNG rng{-50, 50};
  1332. #define cb(name) \
  1333. checker_conv_bias(get_nchw44_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, \
  1334. false, false, false, true), \
  1335. handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \
  1336. dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), \
  1337. dtype::QuantizedS8(60.25f), name); \
  1338. checker_conv_bias( \
  1339. get_nchw44_conv_bias_args({1}, 2, false, true, true, false, true), \
  1340. handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \
  1341. dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), \
  1342. dtype::QuantizedS8(60.25f), name);
  1343. float epsilon = 0.001;
  1344. #if MEGDNN_AARCH64
  1345. cb("IM2COLMATMUL:AARCH64_INT8X8X32_MK4_8X12X4_DOTPROD:96");
  1346. #elif MEGDNN_ARMV7
  1347. cb("IM2COLMATMUL:AARCH32_INT8_MK4_8X4X4_DOTPROD:96");
  1348. #endif
  1349. #undef cb
  1350. }
  1351. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM_MK4_DOT_S2_FUSE) {
  1352. UniformIntRNG rng{-50, 50};
  1353. #define cb(name) \
  1354. checker_conv_bias(get_nchw44_conv_bias_args({3}, 2, false, \
  1355. false, false, false, true), \
  1356. handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \
  1357. dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), \
  1358. dtype::QuantizedS8(60.25f), name); \
  1359. float epsilon = 0.001;
  1360. #if MEGDNN_AARCH64
  1361. cb("IM2COLMATMUL:AARCH64_INT8X8X32_MK4_8X12X4_DOTPROD:96");
  1362. #elif MEGDNN_ARMV7
  1363. cb("IM2COLMATMUL:AARCH32_INT8_MK4_8X4X4_DOTPROD:96");
  1364. #endif
  1365. #undef cb
  1366. }
  1367. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_S8x8x32_MK4_DOT) {
  1368. UniformIntRNG rng{-50, 50};
  1369. #define cb(name) \
  1370. checker_conv_bias( \
  1371. get_nchw44_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, false, \
  1372. true, false, true, false, false, true), \
  1373. handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \
  1374. dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), {}, name); \
  1375. checker_conv_bias( \
  1376. get_nchw44_conv_bias_args({1}, 2, false, true, true, false, true, \
  1377. false, false, true), \
  1378. handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \
  1379. dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), {}, name);
  1380. float epsilon = 0.001;
  1381. #if MEGDNN_AARCH64
  1382. cb("IM2COLMATMUL:AARCH64_INT8X8X32_MK4_8X12X4_DOTPROD:96");
  1383. #elif MEGDNN_ARMV7
  1384. cb("IM2COLMATMUL:AARCH32_INT8_MK4_8X4X4_DOTPROD:96");
  1385. #endif
  1386. #undef cb
  1387. }
  1388. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32_MK4_DOT) {
  1389. UniformIntRNG rng{-50, 50};
  1390. #define cb(name) \
  1391. checker_conv_bias( \
  1392. get_nchw44_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, false, \
  1393. true, false, true, false, false, true), \
  1394. handle(), &rng, epsilon, dtype::Int8(), dtype::Int8(), \
  1395. dtype::Int32(), {}, name); \
  1396. checker_conv_bias( \
  1397. get_nchw44_conv_bias_args({1}, 2, false, true, true, false, true, \
  1398. false, false, true), \
  1399. handle(), &rng, epsilon, dtype::Int8(), dtype::Int8(), \
  1400. dtype::Int32(), {}, name);
  1401. float epsilon = 0.001;
  1402. #if MEGDNN_AARCH64
  1403. cb("IM2COLMATMUL:AARCH64_INT8X8X32_MK4_8X12X4_DOTPROD:96");
  1404. #elif MEGDNN_ARMV7
  1405. cb("IM2COLMATMUL:AARCH32_INT8_MK4_8X4X4_DOTPROD:96");
  1406. #endif
  1407. #undef cb
  1408. }
  1409. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CONV1x1_QUANTIZEDSYM_MK4_DOT) {
  1410. UniformIntRNG rng{-50, 50};
  1411. #define cb(name) \
  1412. checker_conv_bias( \
  1413. get_nchw44_conv_bias_args({1}, 1, true, true, false, false, true), \
  1414. handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \
  1415. dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), \
  1416. dtype::QuantizedS8(60.25f), name); \
  1417. checker_conv_bias( \
  1418. get_nchw44_conv_bias_args({1}, 1, true, true, true, false, true, \
  1419. false, false, true), \
  1420. handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \
  1421. dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), {}, name); \
  1422. checker_conv_bias( \
  1423. get_nchw44_conv_bias_args({1}, 1, true, true, true, false, true, \
  1424. false, false, true), \
  1425. handle(), &rng, epsilon, dtype::Int8(), dtype::Int8(), \
  1426. dtype::Int32(), {}, name);
  1427. float epsilon = 0.001;
  1428. #if MEGDNN_AARCH64
  1429. cb("CONV1x1:AARCH64_INT8X8X32_MK4_8X12X4_DOTPROD");
  1430. #elif MEGDNN_ARMV7
  1431. cb("CONV1x1:AARCH32_INT8_MK4_8X4X4_DOTPROD");
  1432. #endif
  1433. #undef cb
  1434. }
  1435. #endif
  1436. // clang-format on
  1437. #if MEGDNN_AARCH64 || MEGDNN_ARMV7
  1438. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUANTIZEDASYM) {
  1439. NormalRNG rng(128.f);
  1440. #define cb(name) \
  1441. checker_conv_bias(get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, false, \
  1442. false, true, true), \
  1443. handle(), &rng, epsilon, \
  1444. dtype::Quantized8Asymm(1.2f, (uint8_t)125), \
  1445. dtype::Quantized8Asymm(1.3f, (uint8_t)129), \
  1446. dtype::QuantizedS32(1.2 * 1.3), \
  1447. dtype::Quantized8Asymm(50.3f, (uint8_t)120), name); \
  1448. checker_conv_bias( \
  1449. get_conv_bias_args({1}, 2, false, false, false, true, true), \
  1450. handle(), &rng, epsilon, \
  1451. dtype::Quantized8Asymm(1.2f, (uint8_t)125), \
  1452. dtype::Quantized8Asymm(1.3f, (uint8_t)129), \
  1453. dtype::QuantizedS32(1.2 * 1.3), \
  1454. dtype::Quantized8Asymm(50.3f, (uint8_t)120), name);
  1455. float epsilon = 0.001;
  1456. #if MEGDNN_AARCH64
  1457. #if __ARM_FEATURE_DOTPROD
  1458. cb("IM2COLMATMUL:AARCH64_QUINT8_K8X8X4_DOTPROD");
  1459. #else
  1460. cb("IM2COLMATMUL:AARCH64_QUINT8_K8X8X8");
  1461. #endif
  1462. #elif MEGDNN_ARMV7
  1463. epsilon = 1;
  1464. cb("IM2COLMATMUL:ARMV7_QUINT8_K4X8X8");
  1465. #endif
  1466. #undef cb
  1467. }
  1468. #endif
  1469. #if MEGDNN_AARCH64 || MEGDNN_ARMV7
  1470. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUINT8x8x32) {
  1471. UniformIntRNG rng{-50, 50};
  1472. float epsilon = 0.001;
  1473. #define cb(name) \
  1474. checker_conv_bias( \
  1475. get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, true, true), \
  1476. handle(), &rng, epsilon, \
  1477. dtype::Quantized8Asymm(1.2f, (uint8_t)125), \
  1478. dtype::Quantized8Asymm(1.3f, (uint8_t)129), \
  1479. dtype::QuantizedS32(1.2 * 1.3), {}, name); \
  1480. checker_conv_bias(get_conv_bias_args({1}, 2, false, true, true), handle(), \
  1481. &rng, epsilon, \
  1482. dtype::Quantized8Asymm(1.2f, (uint8_t)125), \
  1483. dtype::Quantized8Asymm(1.3f, (uint8_t)129), \
  1484. dtype::QuantizedS32(1.2 * 1.3), {}, name);
  1485. #if MEGDNN_AARCH64
  1486. #if __ARM_FEATURE_DOTPROD
  1487. cb("IM2COLMATMUL:AARCH64_QUINT8_K8X8X4_DOTPROD");
  1488. #else
  1489. cb("IM2COLMATMUL:AARCH64_QUINT8_K8X8X8");
  1490. #endif
  1491. #elif MEGDNN_ARMV7
  1492. #if __ARM_FEATURE_DOTPROD
  1493. cb("IM2COLMATMUL:AARCH32_QUINT8_K4X8X4");
  1494. #endif
  1495. cb("IM2COLMATMUL:ARMV7_QUINT8_K4X8X8");
  1496. #endif
  1497. #undef cb
  1498. }
  1499. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COLMATMUL_INT8x8x16) {
  1500. UniformIntRNG rng{-50, 50};
  1501. float epsilon = 0.001;
  1502. #define cb(name) \
  1503. checker_conv_bias( \
  1504. get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, true, true), \
  1505. handle(), &rng, epsilon, dtype::Int8{}, dtype::Int8{}, \
  1506. dtype::Int16{}, dtype::Int16{}, name); \
  1507. checker_conv_bias(get_conv_bias_args({1}, 2, false, true, true), handle(), \
  1508. &rng, epsilon, dtype::Int8{}, dtype::Int8{}, \
  1509. dtype::Int16{}, dtype::Int16{}, name);
  1510. #if MEGDNN_AARCH64
  1511. cb("IM2COLMATMUL:AARCH64_INT8X8X16_K8X8X8");
  1512. cb("IM2COLMATMUL:AARCH64_INT8X8X16_K4X4X16");
  1513. cb("IM2COLMATMUL:ARM_COMMON_INT8X8X16");
  1514. #elif MEGDNN_ARMV7
  1515. cb("IM2COLMATMUL:ARM_COMMON_INT8X8X16");
  1516. cb("IM2COLMATMUL:ARMV7_INT8X8X16_K4X8X8");
  1517. cb("IM2COLMATMUL:ARMV7_INT8X8X16_K4X2X16");
  1518. #endif
  1519. #undef cb
  1520. }
  1521. #endif
  1522. #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  1523. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_FP16) {
  1524. using namespace conv_bias;
  1525. param::ConvBias cur_param;
  1526. std::vector<conv_bias::TestArg> args =
  1527. get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, false, false);
  1528. std::vector<conv_bias::TestArg> args1 =
  1529. get_conv_bias_args({1}, 2, false, false, false);
  1530. args.insert(args.begin(), args1.begin(), args1.end());
  1531. NormalRNG rng(1);
  1532. #define cb(name) \
  1533. checker_conv_bias(args, handle(), &rng, 0.03, dtype::Float16{}, \
  1534. dtype::Float16{}, dtype::Float16{}, dtype::Float16{}, \
  1535. name);
  1536. #if MEGDNN_AARCH64
  1537. cb("IM2COLMATMUL:AARCH64_F16_K8X24X1");
  1538. #elif MEGDNN_ARMV7
  1539. cb("IM2COLMATMUL:AARCH32_F16_K4X16X1");
  1540. #endif
  1541. #undef cb
  1542. }
  1543. #endif
  1544. void checker_conv_bias_mul_int8x8x32(std::vector<conv_bias::TestArg> args,
  1545. Handle* handle, const char* algo_name) {
  1546. using namespace conv_bias;
  1547. Checker<ConvBias> checker(handle);
  1548. checker.set_before_exec_callback(
  1549. conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name));
  1550. checker.set_dtype(0, dtype::Int8());
  1551. checker.set_dtype(1, dtype::Int8());
  1552. checker.set_dtype(2, dtype::Int32());
  1553. checker.set_dtype(4, dtype::Int32());
  1554. for (auto&& arg : args) {
  1555. checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}});
  1556. }
  1557. UniformIntRNG rng{-50, 50};
  1558. for (auto&& arg : args) {
  1559. checker.set_dtype(0, dtype::QuantizedS8(2.5f))
  1560. .set_dtype(1, dtype::QuantizedS8(2.5f))
  1561. .set_dtype(2, dtype::QuantizedS32(6.25f))
  1562. .set_dtype(4, {})
  1563. .set_rng(0, &rng)
  1564. .set_rng(1, &rng)
  1565. .set_rng(2, &rng)
  1566. .set_param(arg.param)
  1567. .execs({arg.src, arg.filter, {}, {}, {}});
  1568. }
  1569. }
  1570. #if MEGDNN_AARCH64 || MEGDNN_ARMV7
  1571. #if !__ARM_FEATURE_DOTPROD
  1572. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44_S2) {
  1573. using namespace conv_bias;
  1574. std::vector<conv_bias::TestArg> args =
  1575. get_nchw44_conv_bias_args({2, 5, 7}, 2, false, true, true);
  1576. #define cb(name) checker_conv_bias_mul_int8x8x32(args, handle(), name);
  1577. #if MEGDNN_AARCH64
  1578. cb("IM2COLMATMUL:AARCH64_INT8X8X32_MK4_4X4X16:96");
  1579. #else
  1580. cb("IM2COLMATMUL:ARMV7_INT8X8X32_MK4_4X2X16:96");
  1581. #endif
  1582. #undef cb
  1583. }
  1584. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44_S1) {
  1585. using namespace conv_bias;
  1586. std::vector<conv_bias::TestArg> args =
  1587. get_nchw44_conv_bias_args({3, 4, 6}, 1, false, true, true);
  1588. #define cb(name) checker_conv_bias_mul_int8x8x32(args, handle(), name);
  1589. #if MEGDNN_AARCH64
  1590. cb("IM2COLMATMUL:AARCH64_INT8X8X32_MK4_4X4X16:96");
  1591. #else
  1592. cb("IM2COLMATMUL:ARMV7_INT8X8X32_MK4_4X2X16:96");
  1593. #endif
  1594. #undef cb
  1595. }
  1596. TEST_F(ARM_COMMON_MULTI_THREADS,
  1597. CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM_NCHW44_S2) {
  1598. UniformIntRNG rng{-50, 50};
  1599. #define cb(name) \
  1600. checker_conv_bias(get_nchw44_conv_bias_args({3, 4, 6}, 2), handle(), &rng, \
  1601. epsilon, dtype::QuantizedS8(2.5f), \
  1602. dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), \
  1603. dtype::QuantizedS8(60.25f), name);
  1604. float epsilon = 0.001;
  1605. #if MEGDNN_AARCH64
  1606. cb("IM2COLMATMUL:AARCH64_INT8X8X32_MK4_4X4X16:96");
  1607. #else
  1608. cb("IM2COLMATMUL:ARMV7_INT8X8X32_MK4_4X2X16:96");
  1609. #endif
  1610. #undef cb
  1611. }
  1612. TEST_F(ARM_COMMON_MULTI_THREADS,
  1613. CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM_NCHW44_S1) {
  1614. UniformIntRNG rng{-50, 50};
  1615. #define cb(name) \
  1616. checker_conv_bias(get_nchw44_conv_bias_args({2, 5, 7}, 1), handle(), &rng, \
  1617. epsilon, dtype::QuantizedS8(2.5f), \
  1618. dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), \
  1619. dtype::QuantizedS8(60.25f), name);
  1620. float epsilon = 0.001;
  1621. #if MEGDNN_AARCH64
  1622. cb("IM2COLMATMUL:AARCH64_INT8X8X32_MK4_4X4X16:96");
  1623. #else
  1624. cb("IM2COLMATMUL:ARMV7_INT8X8X32_MK4_4X2X16:96");
  1625. #endif
  1626. #undef cb
  1627. }
  1628. #if MEGDNN_AARCH64
  1629. TEST_F(ARM_COMMON_MULTI_THREADS,
  1630. CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM_NCHW44_FUSE) {
  1631. UniformIntRNG rng{-50, 50};
  1632. #define cb(name) \
  1633. checker_conv_bias(get_nchw44_conv_bias_args({3}, 1), handle(), &rng, \
  1634. epsilon, dtype::QuantizedS8(2.5f), \
  1635. dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), \
  1636. dtype::QuantizedS8(60.25f), name);
  1637. float epsilon = 0.001;
  1638. cb("IM2COLMATMUL:AARCH64_INT8X8X32_MK4_4X4X16:96");
  1639. #undef cb
  1640. }
  1641. #endif
  1642. #endif
  1643. #endif
  1644. #if MEGDNN_AARCH64
  1645. #if __ARM_FEATURE_DOTPROD
  1646. TEST_F(ARM_COMMON_MULTI_THREADS,
  1647. CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM_NCHW44DOT_FUSE) {
  1648. UniformIntRNG rng{-50, 50};
  1649. #define cb(name) \
  1650. checker_conv_bias( \
  1651. get_nchw44_conv_bias_args({3}, 1, false, false, false, false, \
  1652. true, false, false, false), \
  1653. handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \
  1654. dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), \
  1655. dtype::QuantizedS8(60.25f), name);
  1656. float epsilon = 0.001;
  1657. cb("IM2COLMATMUL:AARCH64_INT8X8X32_MK4_8X12X4_DOTPROD:96");
  1658. #undef cb
  1659. }
  1660. #endif
  1661. #endif
  1662. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32) {
  1663. using namespace conv_bias;
  1664. std::vector<conv_bias::TestArg> args =
  1665. get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, true, true);
  1666. std::vector<conv_bias::TestArg> args1 =
  1667. get_conv_bias_args({1}, 2, false, true, true);
  1668. args.insert(args.begin(), args1.begin(), args1.end());
  1669. #define cb(name) checker_conv_bias_mul_int8x8x32(args, handle(), name);
  1670. #if MEGDNN_AARCH64
  1671. #if __ARM_FEATURE_DOTPROD
  1672. cb("IM2COLMATMUL:AARCH64_INT8X8X32_K8X12X4_DOTPROD");
  1673. #else
  1674. cb("IM2COLMATMUL:AARCH64_INT8X8X32_K8X8X8");
  1675. cb("IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16");
  1676. #endif
  1677. #elif MEGDNN_ARMV7
  1678. #if __ARM_FEATURE_DOTPROD
  1679. cb("IM2COLMATMUL:AARCH32_INT8_K6X8X4");
  1680. #endif
  1681. cb("IM2COLMATMUL:ARMV7_INT8X8X32_K4X8X8");
  1682. #endif
  1683. #if MEGDNN_ARMV7
  1684. cb("IM2COLMATMUL:ARMV7_INT8X8X32_K4X2X16");
  1685. #endif
  1686. #undef cb
  1687. }
  1688. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COL_S1_MK4_PACK_F32) {
  1689. using namespace conv_bias;
  1690. std::vector<conv_bias::TestArg> args = get_nchw44_conv_bias_args(
  1691. {2, 4, 7}, 1, false, false, false, false, false, true, true);
  1692. #if MEGDNN_AARCH64
  1693. check_conv_bias(args, handle(), "IM2COLMATMUL:AARCH64_F32_MK4_K8X12X1");
  1694. #elif MEGDNN_ARMV7
  1695. check_conv_bias(args, handle(), "IM2COLMATMUL:ARMV7_F32_MK4_PACK_4X12");
  1696. #endif
  1697. }
  1698. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COL_S2_MK4_PACK_F32) {
  1699. using namespace conv_bias;
  1700. std::vector<conv_bias::TestArg> args = get_nchw44_conv_bias_args(
  1701. {3, 5, 6}, 2, false, false, false, false, false, true, true);
  1702. #if MEGDNN_AARCH64
  1703. check_conv_bias(args, handle(), "IM2COLMATMUL:AARCH64_F32_MK4_K8X12X1");
  1704. #elif MEGDNN_ARMV7
  1705. check_conv_bias(args, handle(), "IM2COLMATMUL:ARMV7_F32_MK4_PACK_4X12");
  1706. #endif
  1707. }
  1708. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COL_S2_MK4_PACK_F32_FUSE) {
  1709. using namespace conv_bias;
  1710. std::vector<conv_bias::TestArg> args = get_nchw44_conv_bias_args(
  1711. {3}, 2, false, false, false, false, false, true, true, false);
  1712. #if MEGDNN_AARCH64
  1713. check_conv_bias(args, handle(), "IM2COLMATMUL:AARCH64_F32_MK4_K8X12X1");
  1714. #elif MEGDNN_ARMV7
  1715. check_conv_bias(args, handle(), "IM2COLMATMUL:ARMV7_F32_MK4_PACK_4X12");
  1716. #endif
  1717. }
  1718. /***************************** Conv1x1 Algo Test ***********************/
  1719. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_F32) {
  1720. using namespace conv_bias;
  1721. std::vector<conv_bias::TestArg> args = get_conv_bias_1x1_args(false, false);
  1722. #if MEGDNN_AARCH64
  1723. check_conv_bias(args, handle(), "CONV1x1:AARCH64_F32K8X12X1:24");
  1724. #elif MEGDNN_ARMV7
  1725. check_conv_bias(args, handle(), "CONV1x1:ARMV7_F32:48");
  1726. #endif
  1727. std::vector<conv_bias::TestArg> gemv_args;
  1728. for (auto&& arg : args)
  1729. if(arg.src.shape[2] == 1 && arg.src.shape[3] == 1) {
  1730. gemv_args.emplace_back(arg);
  1731. }
  1732. check_conv_bias(gemv_args, handle(), "CONV1x1_GEMV");
  1733. }
  1734. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_MK4_PACK_F32) {
  1735. using namespace conv_bias;
  1736. std::vector<conv_bias::TestArg> args =
  1737. get_nchw44_conv_bias_args({1}, 1, true, false, false);
  1738. #if MEGDNN_AARCH64
  1739. check_conv_bias(args, handle(), "CONV1x1:AARCH64_F32_MK4_K8X12X1:24");
  1740. #elif MEGDNN_ARMV7
  1741. check_conv_bias(args, handle(), "CONV1x1:ARMV7_F32_MK4_PACK_4X12:24");
  1742. #endif
  1743. }
  1744. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_MK4_NO_PACK_F32) {
  1745. using namespace conv_bias;
  1746. std::vector<conv_bias::TestArg> args =
  1747. get_nchw44_conv_bias_args({1}, 1, true, false, false);
  1748. std::vector<conv_bias::TestArg> args_of_4;
  1749. for (auto&& arg : args) {
  1750. if (arg.src.shape[2] * arg.src.shape[3] % 4 == 0) {
  1751. args_of_4.push_back(arg);
  1752. }
  1753. }
  1754. #if MEGDNN_AARCH64
  1755. check_conv_bias(args_of_4, handle(), "CONV1x1:AARCH64_F32_MK4_4x16:24");
  1756. #elif MEGDNN_ARMV7
  1757. check_conv_bias(args_of_4, handle(), "CONV1x1:ARMV7_F32_MK4_4x8:48");
  1758. #endif
  1759. }
  1760. #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  1761. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_F16) {
  1762. using namespace conv_bias;
  1763. std::vector<conv_bias::TestArg> args = get_conv_bias_1x1_args(false, false);
  1764. NormalRNG rng(1);
  1765. #if MEGDNN_AARCH64
  1766. checker_conv_bias(args, handle(), &rng, 0.03, dtype::Float16{},
  1767. dtype::Float16{}, dtype::Float16{}, dtype::Float16{},
  1768. "CONV1x1:AARCH64_F16_K8X24X1:48");
  1769. #elif MEGDNN_ARMV7
  1770. checker_conv_bias(args, handle(), &rng, 0.03, dtype::Float16{},
  1771. dtype::Float16{}, dtype::Float16{}, dtype::Float16{},
  1772. "CONV1x1:AARCH32_F16_K4X16X1:24");
  1773. #endif
  1774. std::vector<conv_bias::TestArg> gemv_args;
  1775. for (auto&& arg : args)
  1776. if(arg.src.shape[2] == 1 && arg.src.shape[3] == 1) {
  1777. gemv_args.emplace_back(arg);
  1778. }
  1779. check_conv_bias(gemv_args, handle(), "CONV1x1_GEMV");
  1780. }
  1781. #endif
  1782. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_QUANTIZEDSYM) {
  1783. UniformIntRNG rng{-50, 50};
  1784. float epsilon = 0.001;
  1785. std::vector<conv_bias::TestArg> args =
  1786. get_conv_bias_1x1_args(false, false, true, true);
  1787. #define cb(name) \
  1788. checker_conv_bias(args, handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \
  1789. dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), \
  1790. dtype::QuantizedS8(60.25f), name);
  1791. #if MEGDNN_AARCH64
  1792. #if __ARM_FEATURE_DOTPROD
  1793. cb("CONV1x1:AARCH64_INT8X8X32_K8X12X4_DOTPROD:24");
  1794. #else
  1795. cb("CONV1x1:AARCH64_INT8X8X32_K8X8X8:24");
  1796. cb("CONV1x1:AARCH64_INT8X8X32_K4X4X16:48");
  1797. #endif
  1798. #elif MEGDNN_ARMV7
  1799. epsilon = 1;
  1800. cb("CONV1x1:ARMV7_INT8X8X32_K4X8X8:48");
  1801. #endif
  1802. #undef cb
  1803. std::vector<conv_bias::TestArg> gemv_args;
  1804. for (auto&& arg : args)
  1805. if (arg.src.shape[2] == 1 && arg.src.shape[3] == 1) {
  1806. gemv_args.emplace_back(arg);
  1807. }
  1808. checker_conv_bias(gemv_args, handle(), &rng, epsilon,
  1809. dtype::QuantizedS8(2.5f), dtype::QuantizedS8(2.5f),
  1810. dtype::QuantizedS32(6.25f), dtype::QuantizedS8(60.25f),
  1811. "CONV1x1_GEMV");
  1812. }
  1813. #if MEGDNN_AARCH64 || MEGDNN_ARMV7
  1814. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_QUANTIZEDASYM) {
  1815. UniformIntRNG rng{-50, 50};
  1816. std::vector<conv_bias::TestArg> args =
  1817. get_conv_bias_1x1_args(false, false, true, true);
  1818. #define cb(name) \
  1819. checker_conv_bias(args, handle(), &rng, epsilon, \
  1820. dtype::Quantized8Asymm(1.2f, (uint8_t)125), \
  1821. dtype::Quantized8Asymm(1.3f, (uint8_t)129), \
  1822. dtype::QuantizedS32(1.2 * 1.3), \
  1823. dtype::Quantized8Asymm(50.3f, (uint8_t)120), name);
  1824. float epsilon = 0.001;
  1825. #if MEGDNN_AARCH64
  1826. #if __ARM_FEATURE_DOTPROD
  1827. cb("CONV1x1:AARCH64_QUINT8_K8X8X4_DOTPROD:48");
  1828. #else
  1829. cb("CONV1x1:AARCH64_QUINT8_K8X8X8:24");
  1830. #endif
  1831. #elif MEGDNN_ARMV7
  1832. epsilon = 1;
  1833. cb("CONV1x1:ARMV7_QUINT8_K4X8X8:48");
  1834. #endif
  1835. #undef cb
  1836. std::vector<conv_bias::TestArg> gemv_args;
  1837. for (auto&& arg : args)
  1838. if (arg.src.shape[2] == 1 && arg.src.shape[3] == 1) {
  1839. gemv_args.emplace_back(arg);
  1840. }
  1841. checker_conv_bias(gemv_args, handle(), &rng, epsilon,
  1842. dtype::Quantized8Asymm(1.2f, (uint8_t)125),
  1843. dtype::Quantized8Asymm(1.3f, (uint8_t)129),
  1844. dtype::QuantizedS32(1.2 * 1.3),
  1845. dtype::Quantized8Asymm(50.3f, (uint8_t)120),
  1846. "CONV1x1_GEMV");
  1847. }
  1848. #endif
  1849. #if MEGDNN_AARCH64 || MEGDNN_ARMV7
  1850. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_QUINT8x8x32) {
  1851. NormalRNG rng(128.f);
  1852. float epsilon = 0.001;
  1853. std::vector<conv_bias::TestArg> args = get_conv_bias_1x1_args(true, true);
  1854. #define cb(name) \
  1855. checker_conv_bias(args, handle(), &rng, epsilon, \
  1856. dtype::Quantized8Asymm(1.2f, (uint8_t)125), \
  1857. dtype::Quantized8Asymm(1.3f, (uint8_t)129), \
  1858. dtype::QuantizedS32(1.2 * 1.3), {}, name);
  1859. #if MEGDNN_AARCH64
  1860. #if __ARM_FEATURE_DOTPROD
  1861. cb("CONV1x1:AARCH64_QUINT8_K8X8X4_DOTPROD:24");
  1862. #else
  1863. cb("CONV1x1:AARCH64_QUINT8_K8X8X8:48");
  1864. #endif
  1865. #elif MEGDNN_ARMV7
  1866. #if __ARM_FEATURE_DOTPROD
  1867. cb("CONV1x1:AARCH32_QUINT8_K4X8X4:48");
  1868. #endif
  1869. cb("CONV1x1:ARMV7_QUINT8_K4X8X8:24");
  1870. #endif
  1871. #undef cb
  1872. std::vector<conv_bias::TestArg> gemv_args;
  1873. for (auto&& arg : args)
  1874. if (arg.src.shape[2] == 1 && arg.src.shape[3] == 1) {
  1875. gemv_args.emplace_back(arg);
  1876. }
  1877. checker_conv_bias(gemv_args, handle(), &rng, epsilon,
  1878. dtype::Quantized8Asymm(1.2f, (uint8_t)125),
  1879. dtype::Quantized8Asymm(1.3f, (uint8_t)129),
  1880. dtype::QuantizedS32(1.2 * 1.3), {}, "CONV1x1_GEMV");
  1881. }
  1882. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_1X1_S1_INT8x8x16) {
  1883. UniformIntRNG rng{-50, 50};
  1884. float epsilon = 0.001;
  1885. std::vector<conv_bias::TestArg> args = get_conv_bias_1x1_args(true, true);
  1886. #define cb(name) \
  1887. checker_conv_bias(args, handle(), &rng, epsilon, dtype::Int8{}, \
  1888. dtype::Int8{}, dtype::Int16{}, dtype::Int16{}, name);
  1889. #if MEGDNN_AARCH64
  1890. cb("CONV1x1:AARCH64_INT8X8X16_K8X8X8:24");
  1891. cb("CONV1x1:AARCH64_INT8X8X16_K4X4X16:24");
  1892. #elif MEGDNN_ARMV7
  1893. cb("CONV1x1:ARMV7_INT8X8X16_K4X8X8:24");
  1894. cb("CONV1x1:ARMV7_INT8X8X16_K4X2X16:48");
  1895. #endif
  1896. cb("CONV1x1:ARM_COMMON_INT8X8X16:48");
  1897. #undef cb
  1898. std::vector<conv_bias::TestArg> gemv_args;
  1899. for (auto&& arg : args)
  1900. if (arg.src.shape[2] == 1 && arg.src.shape[3] == 1) {
  1901. gemv_args.emplace_back(arg);
  1902. }
  1903. checker_conv_bias(gemv_args, handle(), &rng, epsilon, dtype::Int8{},
  1904. dtype::Int8{}, dtype::Int16{}, dtype::Int16{},
  1905. "CONV1x1_GEMV");
  1906. }
  1907. #endif
  1908. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_INT8x8x32) {
  1909. using namespace conv_bias;
  1910. std::vector<conv_bias::TestArg> args = get_conv_bias_1x1_args(true, true);
  1911. #define cb(name) checker_conv_bias_mul_int8x8x32(args, handle(), name);
  1912. #if MEGDNN_AARCH64
  1913. #if __ARM_FEATURE_DOTPROD
  1914. cb("CONV1x1:AARCH64_INT8X8X32_K8X12X4_DOTPROD:48");
  1915. #else
  1916. cb("CONV1x1:AARCH64_INT8X8X32_K8X8X8:24");
  1917. cb("CONV1x1:AARCH64_INT8X8X32_K4X4X16:24");
  1918. #endif
  1919. #elif MEGDNN_ARMV7
  1920. #if __ARM_FEATURE_DOTPROD
  1921. cb("CONV1x1:AARCH32_INT8_K6X8X4:48");
  1922. #endif
  1923. cb("CONV1x1:ARMV7_INT8X8X32_K4X8X8:24");
  1924. #endif
  1925. #if MEGDNN_ARMV7
  1926. cb("CONV1x1:ARMV7_INT8X8X32_K4X2X16:48");
  1927. #endif
  1928. #undef cb
  1929. std::vector<conv_bias::TestArg> gemv_args;
  1930. for (auto&& arg : args)
  1931. if (arg.src.shape[2] == 1 && arg.src.shape[3] == 1) {
  1932. gemv_args.emplace_back(arg);
  1933. }
  1934. checker_conv_bias_mul_int8x8x32(gemv_args, handle(), "CONV1x1_GEMV");
  1935. }
  1936. #ifndef __ARM_FEATURE_DOTPROD
  1937. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_INT8x8x32_MK4) {
  1938. using namespace conv_bias;
  1939. std::vector<conv_bias::TestArg> args =
  1940. get_nchw44_conv_bias_args({1}, 1, true, true, true);
  1941. #define cb(name) checker_conv_bias_mul_int8x8x32(args, handle(), name);
  1942. #if MEGDNN_AARCH64
  1943. cb("CONV1x1:AARCH64_INT8X8X32_MK4_4X4X16:24");
  1944. #elif MEGDNN_ARMV7
  1945. cb("CONV1x1:ARMV7_INT8X8X32_MK4_4X2X16:24");
  1946. #endif
  1947. #undef cb
  1948. UniformIntRNG rng{-50, 50};
  1949. float epsilon = 0.001;
  1950. #define cb(name) \
  1951. checker_conv_bias(get_nchw44_conv_bias_args({1}, 1, true, false, false), \
  1952. handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \
  1953. dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), \
  1954. dtype::QuantizedS8(60.25f), name);
  1955. #if MEGDNN_AARCH64
  1956. cb("CONV1x1:AARCH64_INT8X8X32_MK4_4X4X16:24");
  1957. #elif MEGDNN_ARMV7
  1958. cb("CONV1x1:ARMV7_INT8X8X32_MK4_4X2X16:24");
  1959. #endif
  1960. #undef cb
  1961. }
  1962. #endif
  1963. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。