/**
 * \file src/x86/conv_bias/int8/avx2_chanwise_kern.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "src/x86/conv_bias/int8/avx2_chanwise_kern.h"
#include <immintrin.h>
#include "src/common/unroll_macro.h"
#include "src/x86/conv_bias/int8/common_helper.h"
#include "src/x86/elemwise_op.h"
#ifdef WIN32CMAKE
#include <smmintrin.h>
#endif

namespace megdnn {
namespace x86 {
namespace avx2_chanwise_stride1 {

#define load_filter(i) __m128i k_##i = _mm_set1_epi8(*(filter + i));

#define load_src0(i) \
    __m256i cvt16_src##i##0 = _mm256_cvtepi8_epi16_from_ptr(r##i);
#define load_src1(i) \
    __m256i cvt16_src##i##1 = _mm256_cvtepi8_epi16_from_ptr(r##i + 1);
#define load_src2(i) \
    __m256i cvt16_src##i##2 = _mm256_cvtepi8_epi16_from_ptr(r##i + 2);
#define load_src3(i) \
    __m256i cvt16_src##i##3 = _mm256_cvtepi8_epi16_from_ptr(r##i + 3);
#define load_src4(i) \
    __m256i cvt16_src##i##4 = _mm256_cvtepi8_epi16_from_ptr(r##i + 4);
#define load_src5(i) \
    __m256i cvt16_src##i##5 = _mm256_cvtepi8_epi16_from_ptr(r##i + 5);
#define load_src6(i) \
    __m256i cvt16_src##i##6 = _mm256_cvtepi8_epi16_from_ptr(r##i + 6);
#define load_src7(i) \
    __m256i cvt16_src##i##7 = _mm256_cvtepi8_epi16_from_ptr(r##i + 7);
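
//! load_filter(i) broadcasts filter tap i into an __m128i; load_srcN(i) loads
//! 16 consecutive int8 inputs from row r##i at column offset N and widens them
//! to int16. The kernels below interleave every two adjacent filter taps into
//! 16-bit pairs, so a single _mm256_madd_epi16 multiplies both taps against
//! two neighbouring pixels and sums the products. The accumulators built from
//! column offsets 0 and 1 hold the even- and odd-indexed output columns; they
//! are restored to natural order with _mm256_unpacklo/unpackhi_epi32 followed
//! by _mm256_permute2f128_si256 (immediates 32 and 49) before the bias is
//! added and the result is either quantized by op or stored as int32.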
template <BiasMode bias_mode, bool is_quantized, typename Op>
void avx2_chanwise_direct_stride1_2x2_int8(const int8_t* src,
                                           const int8_t* filter,
                                           const int32_t* bias, int32_t* temp,
                                           int8_t* dst, const size_t IH,
                                           const size_t IW, const size_t OH,
                                           const size_t OW, const Op& op) {
    size_t tail_step = IW - OW;
    int8_t* dst0 = dst;
    int8_t* dst1 = dst + OW;
    int32_t* out_ptr0 = temp;
    int32_t* out_ptr1 = temp + OW;
    const int8_t* r0 = src;
    const int8_t* r1 = src + IW;
    const int8_t* r2 = src + 2 * IW;

    UNROLL_CALL0(4, load_filter)

#define pack_filter(i, j) __m128i k_##i##j = _mm_unpacklo_epi8(k_##i, k_##j)
    pack_filter(0, 1);
    pack_filter(2, 3);

    __m256i bias_val;
    if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) {
        bias_val = _mm256_set1_epi32(*(bias));
    } else {
        bias_val = _mm256_set1_epi32(0);
    }

#define cvt_filter(i, j) __m256i filter_##i##j = _mm256_cvtepi8_epi16(k_##i##j)
    cvt_filter(0, 1);
    cvt_filter(2, 3);

    size_t width = OW >> 4;
    size_t h = 0;
    for (; h + 1 < OH; h += 2) {
        size_t w = 0;
        for (; w < width; w++) {
            UNROLL_CALL0(3, load_src0)
            UNROLL_CALL0(3, load_src1)
            __m256i sum0_odd, sum0_even, sum1_odd, sum1_even;
            __m256i tmp0_odd, tmp0_even, tmp1_odd, tmp1_even, tmp2_odd,
                    tmp2_even, tmp3_odd, tmp3_even;

            tmp0_odd = _mm256_madd_epi16(cvt16_src00, filter_01);
            tmp0_even = _mm256_madd_epi16(cvt16_src01, filter_01);
            tmp1_odd = _mm256_madd_epi16(cvt16_src10, filter_23);
            tmp1_even = _mm256_madd_epi16(cvt16_src11, filter_23);
            tmp3_odd = _mm256_madd_epi16(cvt16_src10, filter_01);
            tmp3_even = _mm256_madd_epi16(cvt16_src11, filter_01);
            tmp2_odd = _mm256_madd_epi16(cvt16_src20, filter_23);
            tmp2_even = _mm256_madd_epi16(cvt16_src21, filter_23);

            sum0_odd = _mm256_add_epi32(tmp0_odd, tmp1_odd);
            sum0_even = _mm256_add_epi32(tmp0_even, tmp1_even);
            __m256i sum_odd = _mm256_unpacklo_epi32(sum0_odd, sum0_even);
            __m256i sum_even = _mm256_unpackhi_epi32(sum0_odd, sum0_even);
            //! switch_mask_low = {00100000} = 32
            //! switch_mask_high = {00110001} = 49
            __m256i sum_left = _mm256_permute2f128_si256(sum_odd, sum_even, 32);
            __m256i sum_right =
                    _mm256_permute2f128_si256(sum_odd, sum_even, 49);
            sum_left = _mm256_add_epi32(sum_left, bias_val);
            sum_right = _mm256_add_epi32(sum_right, bias_val);
            if (is_quantized) {
                op({{sum_left, sum_right}}, reinterpret_cast<dt_qint8*>(dst0));
            } else {
                _mm256_storeu_si256((__m256i*)(out_ptr0), sum_left);
                _mm256_storeu_si256((__m256i*)(out_ptr0 + 8), sum_right);
            }

            sum1_odd = _mm256_add_epi32(tmp3_odd, tmp2_odd);
            sum1_even = _mm256_add_epi32(tmp3_even, tmp2_even);
            __m256i sum_1_odd = _mm256_unpacklo_epi32(sum1_odd, sum1_even);
            __m256i sum_1_even = _mm256_unpackhi_epi32(sum1_odd, sum1_even);
            __m256i sum_1_left =
                    _mm256_permute2f128_si256(sum_1_odd, sum_1_even, 32);
            __m256i sum_1_right =
                    _mm256_permute2f128_si256(sum_1_odd, sum_1_even, 49);
            sum_1_left = _mm256_add_epi32(sum_1_left, bias_val);
            sum_1_right = _mm256_add_epi32(sum_1_right, bias_val);
            if (is_quantized) {
                op({{sum_1_left, sum_1_right}},
                   reinterpret_cast<dt_qint8*>(dst1));
            } else {
                _mm256_storeu_si256((__m256i*)(out_ptr1), sum_1_left);
                _mm256_storeu_si256((__m256i*)(out_ptr1 + 8), sum_1_right);
            }

            r0 += 16;
            r1 += 16;
            r2 += 16;
            dst0 += 16;
            dst1 += 16;
            out_ptr0 += 16;
            out_ptr1 += 16;
        }
        r0 += tail_step + IW;
        r1 += tail_step + IW;
        r2 += tail_step + IW;
        dst0 += OW;
        dst1 += OW;
        out_ptr0 += OW;
        out_ptr1 += OW;
    }
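
    //! tail: when OH is odd the last output row is computed on its own,
    //! reusing the same even/odd column pipeline for a single row.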
    for (; h < OH; h++) {
        size_t w = 0;
        for (; w < width; w++) {
            UNROLL_CALL0(2, load_src0)
            UNROLL_CALL0(2, load_src1)
            __m256i sum0_odd, sum0_even;
            __m256i tmp0_odd, tmp0_even, tmp1_odd, tmp1_even;
            tmp0_odd = _mm256_madd_epi16(cvt16_src00, filter_01);
            tmp0_even = _mm256_madd_epi16(cvt16_src01, filter_01);
            tmp1_odd = _mm256_madd_epi16(cvt16_src10, filter_23);
            tmp1_even = _mm256_madd_epi16(cvt16_src11, filter_23);
            sum0_odd = _mm256_add_epi32(tmp0_odd, tmp1_odd);
            sum0_even = _mm256_add_epi32(tmp0_even, tmp1_even);
            __m256i sum_odd = _mm256_unpacklo_epi32(sum0_odd, sum0_even);
            __m256i sum_even = _mm256_unpackhi_epi32(sum0_odd, sum0_even);
            __m256i sum_left = _mm256_permute2f128_si256(sum_odd, sum_even, 32);
            __m256i sum_right =
                    _mm256_permute2f128_si256(sum_odd, sum_even, 49);
            sum_left = _mm256_add_epi32(sum_left, bias_val);
            sum_right = _mm256_add_epi32(sum_right, bias_val);
            if (is_quantized) {
                op({{sum_left, sum_right}}, reinterpret_cast<dt_qint8*>(dst0));
            } else {
                _mm256_storeu_si256((__m256i*)(out_ptr0), sum_left);
                _mm256_storeu_si256((__m256i*)(out_ptr0 + 8), sum_right);
            }
            r0 += 16;
            r1 += 16;
            dst0 += 16;
            out_ptr0 += 16;
        }
        r0 += tail_step;
        r1 += tail_step;
    }
    MEGDNN_MARK_USED_VAR(IH);
#undef pack_filter
#undef cvt_filter
}
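
//! 3x3: each filter row has an odd number of taps, so the last tap of every
//! row is paired with a zero (k_fill) to keep the 16-bit pairing intact.
//! Four input rows are loaded per pass; rows 1 and 2 are reused so that two
//! output rows are produced per iteration of the outer loop.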
template <BiasMode bias_mode, bool is_quantized, typename Op>
void avx2_chanwise_direct_stride1_3x3_int8(const int8_t* src,
                                           const int8_t* filter,
                                           const int32_t* bias, int32_t* temp,
                                           int8_t* dst, const size_t IH,
                                           const size_t IW, const size_t OH,
                                           const size_t OW, const Op& op) {
    MEGDNN_MARK_USED_VAR(IH);
    size_t tail_step = IW - OW;
    int32_t* out_ptr0 = temp;
    int32_t* out_ptr1 = temp + OW;
    int8_t* dst0 = dst;
    int8_t* dst1 = dst + OW;
    const int8_t* r0 = src;
    const int8_t* r1 = src + IW;
    const int8_t* r2 = src + 2 * IW;
    const int8_t* r3 = src + 3 * IW;

    uint8_t fill_zero = 0;
    UNROLL_CALL0(9, load_filter)
    __m128i k_fill = _mm_set1_epi8(fill_zero);
    __m128i k01 = _mm_unpacklo_epi8(k_0, k_1);
    __m128i k20 = _mm_unpacklo_epi8(k_2, k_fill);
    __m128i k34 = _mm_unpacklo_epi8(k_3, k_4);
    __m128i k50 = _mm_unpacklo_epi8(k_5, k_fill);
    __m128i k67 = _mm_unpacklo_epi8(k_6, k_7);
    __m128i k80 = _mm_unpacklo_epi8(k_8, k_fill);

    __m256i bias_val;
    if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) {
        bias_val = _mm256_set1_epi32(*(bias));
    } else {
        bias_val = _mm256_set1_epi32(0);
    }

    //! cvt i8 --> i16
    __m256i filter_01 = _mm256_cvtepi8_epi16(k01);
    __m256i filter_20 = _mm256_cvtepi8_epi16(k20);
    __m256i filter_34 = _mm256_cvtepi8_epi16(k34);
    __m256i filter_50 = _mm256_cvtepi8_epi16(k50);
    __m256i filter_67 = _mm256_cvtepi8_epi16(k67);
    __m256i filter_80 = _mm256_cvtepi8_epi16(k80);

    size_t width = OW >> 4;
    size_t h = 0;
    for (; h + 1 < OH; h += 2) {
        size_t w = 0;
        for (; w < width; w++) {
            UNROLL_CALL0(4, load_src0)
            UNROLL_CALL0(4, load_src1)
            UNROLL_CALL0(4, load_src2)
            UNROLL_CALL0(4, load_src3)
            __m256i sum00_odd, sum00_even, sum11_odd, sum11_even, sum22_odd,
                    sum22_even;
            __m256i sum11_odd_01, sum11_even_01, sum22_odd_01, sum22_even_01,
                    sum33_odd, sum33_even;
            __m256i temp0, temp1;

            temp0 = _mm256_madd_epi16(cvt16_src00, filter_01);
            temp1 = _mm256_madd_epi16(cvt16_src02, filter_20);
            sum00_odd = _mm256_add_epi32(temp0, temp1);
            temp0 = _mm256_madd_epi16(cvt16_src01, filter_01);
            temp1 = _mm256_madd_epi16(cvt16_src03, filter_20);
            sum00_even = _mm256_add_epi32(temp0, temp1);

            temp0 = _mm256_madd_epi16(cvt16_src10, filter_34);
            temp1 = _mm256_madd_epi16(cvt16_src12, filter_50);
            sum11_odd = _mm256_add_epi32(temp0, temp1);
            temp0 = _mm256_madd_epi16(cvt16_src11, filter_34);
            temp1 = _mm256_madd_epi16(cvt16_src13, filter_50);
            sum11_even = _mm256_add_epi32(temp0, temp1);

            temp0 = _mm256_madd_epi16(cvt16_src10, filter_01);
            temp1 = _mm256_madd_epi16(cvt16_src12, filter_20);
            sum11_odd_01 = _mm256_add_epi32(temp0, temp1);
            temp0 = _mm256_madd_epi16(cvt16_src11, filter_01);
            temp1 = _mm256_madd_epi16(cvt16_src13, filter_20);
            sum11_even_01 = _mm256_add_epi32(temp0, temp1);

            temp0 = _mm256_madd_epi16(cvt16_src20, filter_67);
            temp1 = _mm256_madd_epi16(cvt16_src22, filter_80);
            sum22_odd = _mm256_add_epi32(temp0, temp1);
            temp0 = _mm256_madd_epi16(cvt16_src21, filter_67);
            temp1 = _mm256_madd_epi16(cvt16_src23, filter_80);
            sum22_even = _mm256_add_epi32(temp0, temp1);

            temp0 = _mm256_madd_epi16(cvt16_src20, filter_34);
            temp1 = _mm256_madd_epi16(cvt16_src22, filter_50);
            sum22_odd_01 = _mm256_add_epi32(temp0, temp1);
            temp0 = _mm256_madd_epi16(cvt16_src21, filter_34);
            temp1 = _mm256_madd_epi16(cvt16_src23, filter_50);
            sum22_even_01 = _mm256_add_epi32(temp0, temp1);

            temp0 = _mm256_madd_epi16(cvt16_src30, filter_67);
            temp1 = _mm256_madd_epi16(cvt16_src32, filter_80);
            sum33_odd = _mm256_add_epi32(temp0, temp1);
            temp0 = _mm256_madd_epi16(cvt16_src31, filter_67);
            temp1 = _mm256_madd_epi16(cvt16_src33, filter_80);
            sum33_even = _mm256_add_epi32(temp0, temp1);

            sum00_odd = _mm256_add_epi32(sum00_odd, sum11_odd);
            sum00_odd = _mm256_add_epi32(sum00_odd, sum22_odd);
            sum00_even = _mm256_add_epi32(sum00_even, sum11_even);
            sum00_even = _mm256_add_epi32(sum00_even, sum22_even);
            __m256i sum_odd = _mm256_unpacklo_epi32(sum00_odd, sum00_even);
            __m256i sum_even = _mm256_unpackhi_epi32(sum00_odd, sum00_even);
            __m256i sum_left = _mm256_permute2f128_si256(sum_odd, sum_even, 32);
            __m256i sum_right =
                    _mm256_permute2f128_si256(sum_odd, sum_even, 49);
            sum_left = _mm256_add_epi32(sum_left, bias_val);
            sum_right = _mm256_add_epi32(sum_right, bias_val);
            if (is_quantized) {
                op({{sum_left, sum_right}}, reinterpret_cast<dt_qint8*>(dst0));
            } else {
                _mm256_storeu_si256((__m256i*)(out_ptr0), sum_left);
                _mm256_storeu_si256((__m256i*)(out_ptr0 + 8), sum_right);
            }

            sum11_odd_01 = _mm256_add_epi32(sum11_odd_01, sum22_odd_01);
            sum11_odd_01 = _mm256_add_epi32(sum11_odd_01, sum33_odd);
            sum11_even_01 = _mm256_add_epi32(sum11_even_01, sum22_even_01);
            sum11_even_01 = _mm256_add_epi32(sum11_even_01, sum33_even);
            __m256i sum_oh1_odd =
                    _mm256_unpacklo_epi32(sum11_odd_01, sum11_even_01);
            __m256i sum_oh1_even =
                    _mm256_unpackhi_epi32(sum11_odd_01, sum11_even_01);
            __m256i sum1_left =
                    _mm256_permute2f128_si256(sum_oh1_odd, sum_oh1_even, 32);
            __m256i sum1_right =
                    _mm256_permute2f128_si256(sum_oh1_odd, sum_oh1_even, 49);
            sum1_left = _mm256_add_epi32(sum1_left, bias_val);
            sum1_right = _mm256_add_epi32(sum1_right, bias_val);
            if (is_quantized) {
                op({{sum1_left, sum1_right}},
                   reinterpret_cast<dt_qint8*>(dst1));
            } else {
                _mm256_storeu_si256((__m256i*)(out_ptr1), sum1_left);
                _mm256_storeu_si256((__m256i*)(out_ptr1 + 8), sum1_right);
            }

            r0 += 16;
            r1 += 16;
            r2 += 16;
            r3 += 16;
            dst0 += 16;
            dst1 += 16;
            out_ptr0 += 16;
            out_ptr1 += 16;
        }
        r0 += tail_step + IW;
        r1 += tail_step + IW;
        r2 += tail_step + IW;
        r3 += tail_step + IW;
        dst0 += OW;
        dst1 += OW;
        out_ptr0 += OW;
        out_ptr1 += OW;
    }
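
    //! odd-OH tail: compute the last output row alone from r0..r2.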
    for (; h < OH; h++) {
        size_t w = 0;
        for (; w < width; w++) {
            UNROLL_CALL0(3, load_src0)
            UNROLL_CALL0(3, load_src1)
            UNROLL_CALL0(3, load_src2)
            UNROLL_CALL0(3, load_src3)
            __m256i sum00_odd, sum00_even, sum11_odd, sum11_even, sum22_odd,
                    sum22_even;
            __m256i temp0, temp1;
            temp0 = _mm256_madd_epi16(cvt16_src00, filter_01);
            temp1 = _mm256_madd_epi16(cvt16_src02, filter_20);
            sum00_odd = _mm256_add_epi32(temp0, temp1);
            temp0 = _mm256_madd_epi16(cvt16_src01, filter_01);
            temp1 = _mm256_madd_epi16(cvt16_src03, filter_20);
            sum00_even = _mm256_add_epi32(temp0, temp1);
            temp0 = _mm256_madd_epi16(cvt16_src10, filter_34);
            temp1 = _mm256_madd_epi16(cvt16_src12, filter_50);
            sum11_odd = _mm256_add_epi32(temp0, temp1);
            temp0 = _mm256_madd_epi16(cvt16_src11, filter_34);
            temp1 = _mm256_madd_epi16(cvt16_src13, filter_50);
            sum11_even = _mm256_add_epi32(temp0, temp1);
            temp0 = _mm256_madd_epi16(cvt16_src20, filter_67);
            temp1 = _mm256_madd_epi16(cvt16_src22, filter_80);
            sum22_odd = _mm256_add_epi32(temp0, temp1);
            temp0 = _mm256_madd_epi16(cvt16_src21, filter_67);
            temp1 = _mm256_madd_epi16(cvt16_src23, filter_80);
            sum22_even = _mm256_add_epi32(temp0, temp1);
            sum00_odd = _mm256_add_epi32(sum00_odd, sum11_odd);
            sum00_odd = _mm256_add_epi32(sum00_odd, sum22_odd);
            sum00_even = _mm256_add_epi32(sum00_even, sum11_even);
            sum00_even = _mm256_add_epi32(sum00_even, sum22_even);
            __m256i sum_odd = _mm256_unpacklo_epi32(sum00_odd, sum00_even);
            __m256i sum_even = _mm256_unpackhi_epi32(sum00_odd, sum00_even);
            __m256i sum_left = _mm256_permute2f128_si256(sum_odd, sum_even, 32);
            __m256i sum_right =
                    _mm256_permute2f128_si256(sum_odd, sum_even, 49);
            sum_left = _mm256_add_epi32(sum_left, bias_val);
            sum_right = _mm256_add_epi32(sum_right, bias_val);
            if (is_quantized) {
                op({{sum_left, sum_right}}, reinterpret_cast<dt_qint8*>(dst0));
            } else {
                _mm256_storeu_si256((__m256i*)(out_ptr0), sum_left);
                _mm256_storeu_si256((__m256i*)(out_ptr0 + 8), sum_right);
            }
            r0 += 16;
            r1 += 16;
            r2 += 16;
            dst0 += 16;
            out_ptr0 += 16;
        }
        r0 += tail_step;
        r1 += tail_step;
        r2 += tail_step;
    }
}
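
//! 5x5: filter rows are split into the tap pairs (k0,k1), (k2,k3) and a
//! zero-padded tail (k4,0). Six input rows are loaded per pass and rows 1..4
//! are shared between the two output rows produced by the outer loop.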
template <BiasMode bias_mode, bool is_quantized, typename Op>
void avx2_chanwise_direct_stride1_5x5_int8(const int8_t* src,
                                           const int8_t* filter,
                                           const int32_t* bias, int32_t* temp,
                                           int8_t* dst, const size_t IH,
                                           const size_t IW, const size_t OH,
                                           const size_t OW, const Op& op) {
    MEGDNN_MARK_USED_VAR(IH);
    size_t tail_step = IW - OW;
    int8_t* dst0 = dst;
    int8_t* dst1 = dst + OW;
    int32_t* out_ptr0 = temp;
    int32_t* out_ptr1 = temp + OW;
    const int8_t* r0 = src;
    const int8_t* r1 = src + IW;
    const int8_t* r2 = src + 2 * IW;
    const int8_t* r3 = src + 3 * IW;
    const int8_t* r4 = src + 4 * IW;
    const int8_t* r5 = src + 5 * IW;

    uint8_t fill_zero = 0;
    UNROLL_CALL0(25, load_filter)
    __m128i k_fill = _mm_set1_epi8(fill_zero);
    __m128i k01 = _mm_unpacklo_epi8(k_0, k_1);
    __m128i k23 = _mm_unpacklo_epi8(k_2, k_3);
    __m128i k40 = _mm_unpacklo_epi8(k_4, k_fill);
    __m128i k56 = _mm_unpacklo_epi8(k_5, k_6);
    __m128i k78 = _mm_unpacklo_epi8(k_7, k_8);
    __m128i k90 = _mm_unpacklo_epi8(k_9, k_fill);
    __m128i k1011 = _mm_unpacklo_epi8(k_10, k_11);
    __m128i k1213 = _mm_unpacklo_epi8(k_12, k_13);
    __m128i k140 = _mm_unpacklo_epi8(k_14, k_fill);
    __m128i k1516 = _mm_unpacklo_epi8(k_15, k_16);
    __m128i k1718 = _mm_unpacklo_epi8(k_17, k_18);
    __m128i k190 = _mm_unpacklo_epi8(k_19, k_fill);
    __m128i k2021 = _mm_unpacklo_epi8(k_20, k_21);
    __m128i k2223 = _mm_unpacklo_epi8(k_22, k_23);
    __m128i k240 = _mm_unpacklo_epi8(k_24, k_fill);

    __m256i bias_val;
    //! load bias
    if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) {
        bias_val = _mm256_set1_epi32(*(bias));
    } else {
        bias_val = _mm256_set1_epi32(0);
    }

    //! cvt i8 --> i16
    __m256i filter_01 = _mm256_cvtepi8_epi16(k01);
    __m256i filter_23 = _mm256_cvtepi8_epi16(k23);
    __m256i filter_40 = _mm256_cvtepi8_epi16(k40);
    __m256i filter_56 = _mm256_cvtepi8_epi16(k56);
    __m256i filter_78 = _mm256_cvtepi8_epi16(k78);
    __m256i filter_90 = _mm256_cvtepi8_epi16(k90);
    __m256i filter_1011 = _mm256_cvtepi8_epi16(k1011);
    __m256i filter_1213 = _mm256_cvtepi8_epi16(k1213);
    __m256i filter_140 = _mm256_cvtepi8_epi16(k140);
    __m256i filter_1516 = _mm256_cvtepi8_epi16(k1516);
    __m256i filter_1718 = _mm256_cvtepi8_epi16(k1718);
    __m256i filter_190 = _mm256_cvtepi8_epi16(k190);
    __m256i filter_2021 = _mm256_cvtepi8_epi16(k2021);
    __m256i filter_2223 = _mm256_cvtepi8_epi16(k2223);
    __m256i filter_240 = _mm256_cvtepi8_epi16(k240);

    size_t width = OW >> 4;
    size_t h = 0;
    for (; h + 1 < OH; h += 2) {
        size_t w = 0;
        for (; w < width; w++) {
            UNROLL_CALL0(6, load_src0)
            UNROLL_CALL0(6, load_src1)
            UNROLL_CALL0(6, load_src2)
            UNROLL_CALL0(6, load_src3)
            UNROLL_CALL0(6, load_src4)
            UNROLL_CALL0(6, load_src5)
            __m256i sum0_odd, sum0_even, sum1_odd, sum1_even, sum2_odd,
                    sum2_even, sum3_odd, sum3_even, sum4_odd, sum4_even;
            __m256i sum10_odd, sum10_even, sum20_odd, sum20_even, sum30_odd,
                    sum30_even, sum40_odd, sum40_even, sum5_odd, sum5_even;
            //! cal src0
            __m256i dot1, dot2, dot3;
            dot1 = _mm256_madd_epi16(cvt16_src00, filter_01);
            dot2 = _mm256_madd_epi16(cvt16_src02, filter_23);
            dot3 = _mm256_madd_epi16(cvt16_src04, filter_40);
            sum0_odd = _mm256_add_epi32(dot1, dot2);
            sum0_odd = _mm256_add_epi32(sum0_odd, dot3);
            dot1 = _mm256_madd_epi16(cvt16_src01, filter_01);
            dot2 = _mm256_madd_epi16(cvt16_src03, filter_23);
            dot3 = _mm256_madd_epi16(cvt16_src05, filter_40);
            sum0_even = _mm256_add_epi32(dot1, dot2);
            sum0_even = _mm256_add_epi32(sum0_even, dot3);
            //! cal src1
            dot1 = _mm256_madd_epi16(cvt16_src10, filter_56);
            dot2 = _mm256_madd_epi16(cvt16_src12, filter_78);
            dot3 = _mm256_madd_epi16(cvt16_src14, filter_90);
            sum1_odd = _mm256_add_epi32(dot1, dot2);
            sum1_odd = _mm256_add_epi32(sum1_odd, dot3);
            dot1 = _mm256_madd_epi16(cvt16_src11, filter_56);
            dot2 = _mm256_madd_epi16(cvt16_src13, filter_78);
            dot3 = _mm256_madd_epi16(cvt16_src15, filter_90);
            sum1_even = _mm256_add_epi32(dot1, dot2);
            sum1_even = _mm256_add_epi32(sum1_even, dot3);
            dot1 = _mm256_madd_epi16(cvt16_src10, filter_01);
            dot2 = _mm256_madd_epi16(cvt16_src12, filter_23);
            dot3 = _mm256_madd_epi16(cvt16_src14, filter_40);
            sum10_odd = _mm256_add_epi32(dot1, dot2);
            sum10_odd = _mm256_add_epi32(sum10_odd, dot3);
            dot1 = _mm256_madd_epi16(cvt16_src11, filter_01);
            dot2 = _mm256_madd_epi16(cvt16_src13, filter_23);
            dot3 = _mm256_madd_epi16(cvt16_src15, filter_40);
            sum10_even = _mm256_add_epi32(dot1, dot2);
            sum10_even = _mm256_add_epi32(sum10_even, dot3);
            //! cal src2
            dot1 = _mm256_madd_epi16(cvt16_src20, filter_1011);
            dot2 = _mm256_madd_epi16(cvt16_src22, filter_1213);
            dot3 = _mm256_madd_epi16(cvt16_src24, filter_140);
            sum2_odd = _mm256_add_epi32(dot1, dot2);
            sum2_odd = _mm256_add_epi32(sum2_odd, dot3);
            dot1 = _mm256_madd_epi16(cvt16_src21, filter_1011);
            dot2 = _mm256_madd_epi16(cvt16_src23, filter_1213);
            dot3 = _mm256_madd_epi16(cvt16_src25, filter_140);
            sum2_even = _mm256_add_epi32(dot1, dot2);
            sum2_even = _mm256_add_epi32(sum2_even, dot3);
            dot1 = _mm256_madd_epi16(cvt16_src20, filter_56);
            dot2 = _mm256_madd_epi16(cvt16_src22, filter_78);
            dot3 = _mm256_madd_epi16(cvt16_src24, filter_90);
            sum20_odd = _mm256_add_epi32(dot1, dot2);
            sum20_odd = _mm256_add_epi32(sum20_odd, dot3);
            dot1 = _mm256_madd_epi16(cvt16_src21, filter_56);
            dot2 = _mm256_madd_epi16(cvt16_src23, filter_78);
            dot3 = _mm256_madd_epi16(cvt16_src25, filter_90);
            sum20_even = _mm256_add_epi32(dot1, dot2);
            sum20_even = _mm256_add_epi32(sum20_even, dot3);
            //! cal src3
            dot1 = _mm256_madd_epi16(cvt16_src30, filter_1516);
            dot2 = _mm256_madd_epi16(cvt16_src32, filter_1718);
            dot3 = _mm256_madd_epi16(cvt16_src34, filter_190);
            sum3_odd = _mm256_add_epi32(dot1, dot2);
            sum3_odd = _mm256_add_epi32(sum3_odd, dot3);
            dot1 = _mm256_madd_epi16(cvt16_src31, filter_1516);
            dot2 = _mm256_madd_epi16(cvt16_src33, filter_1718);
            dot3 = _mm256_madd_epi16(cvt16_src35, filter_190);
            sum3_even = _mm256_add_epi32(dot1, dot2);
            sum3_even = _mm256_add_epi32(sum3_even, dot3);
            dot1 = _mm256_madd_epi16(cvt16_src30, filter_1011);
            dot2 = _mm256_madd_epi16(cvt16_src32, filter_1213);
            dot3 = _mm256_madd_epi16(cvt16_src34, filter_140);
            sum30_odd = _mm256_add_epi32(dot1, dot2);
            sum30_odd = _mm256_add_epi32(sum30_odd, dot3);
            dot1 = _mm256_madd_epi16(cvt16_src31, filter_1011);
            dot2 = _mm256_madd_epi16(cvt16_src33, filter_1213);
            dot3 = _mm256_madd_epi16(cvt16_src35, filter_140);
            sum30_even = _mm256_add_epi32(dot1, dot2);
            sum30_even = _mm256_add_epi32(sum30_even, dot3);
            //! cal src4
            dot1 = _mm256_madd_epi16(cvt16_src40, filter_2021);
            dot2 = _mm256_madd_epi16(cvt16_src42, filter_2223);
            dot3 = _mm256_madd_epi16(cvt16_src44, filter_240);
            sum4_odd = _mm256_add_epi32(dot1, dot2);
            sum4_odd = _mm256_add_epi32(sum4_odd, dot3);
            dot1 = _mm256_madd_epi16(cvt16_src41, filter_2021);
            dot2 = _mm256_madd_epi16(cvt16_src43, filter_2223);
            dot3 = _mm256_madd_epi16(cvt16_src45, filter_240);
            sum4_even = _mm256_add_epi32(dot1, dot2);
            sum4_even = _mm256_add_epi32(sum4_even, dot3);
            dot1 = _mm256_madd_epi16(cvt16_src40, filter_1516);
            dot2 = _mm256_madd_epi16(cvt16_src42, filter_1718);
            dot3 = _mm256_madd_epi16(cvt16_src44, filter_190);
            sum40_odd = _mm256_add_epi32(dot1, dot2);
            sum40_odd = _mm256_add_epi32(sum40_odd, dot3);
            dot1 = _mm256_madd_epi16(cvt16_src41, filter_1516);
            dot2 = _mm256_madd_epi16(cvt16_src43, filter_1718);
            dot3 = _mm256_madd_epi16(cvt16_src45, filter_190);
            sum40_even = _mm256_add_epi32(dot1, dot2);
            sum40_even = _mm256_add_epi32(sum40_even, dot3);
            //! cal src5
            dot1 = _mm256_madd_epi16(cvt16_src50, filter_2021);
            dot2 = _mm256_madd_epi16(cvt16_src52, filter_2223);
            dot3 = _mm256_madd_epi16(cvt16_src54, filter_240);
            sum5_odd = _mm256_add_epi32(dot1, dot2);
            sum5_odd = _mm256_add_epi32(sum5_odd, dot3);
            dot1 = _mm256_madd_epi16(cvt16_src51, filter_2021);
            dot2 = _mm256_madd_epi16(cvt16_src53, filter_2223);
            dot3 = _mm256_madd_epi16(cvt16_src55, filter_240);
            sum5_even = _mm256_add_epi32(dot1, dot2);
            sum5_even = _mm256_add_epi32(sum5_even, dot3);

            __m256i sum_odd, sum_even;
            sum_odd = _mm256_add_epi32(sum0_odd, sum1_odd);
            sum_odd = _mm256_add_epi32(sum_odd, sum2_odd);
            sum_odd = _mm256_add_epi32(sum_odd, sum3_odd);
            sum_odd = _mm256_add_epi32(sum_odd, sum4_odd);
            sum_even = _mm256_add_epi32(sum0_even, sum1_even);
            sum_even = _mm256_add_epi32(sum_even, sum2_even);
            sum_even = _mm256_add_epi32(sum_even, sum3_even);
            sum_even = _mm256_add_epi32(sum_even, sum4_even);
            __m256i sum_odd_0 = _mm256_unpacklo_epi32(sum_odd, sum_even);
            __m256i sum_even_0 = _mm256_unpackhi_epi32(sum_odd, sum_even);
            __m256i sum_left =
                    _mm256_permute2f128_si256(sum_odd_0, sum_even_0, 32);
            __m256i sum_right =
                    _mm256_permute2f128_si256(sum_odd_0, sum_even_0, 49);
            sum_left = _mm256_add_epi32(sum_left, bias_val);
            sum_right = _mm256_add_epi32(sum_right, bias_val);
            if (is_quantized) {
                op({{sum_left, sum_right}}, reinterpret_cast<dt_qint8*>(dst0));
            } else {
                _mm256_storeu_si256((__m256i*)(out_ptr0), sum_left);
                _mm256_storeu_si256((__m256i*)(out_ptr0 + 8), sum_right);
            }

            __m256i sum_odd_oh1, sum_even_oh1;
            sum_odd_oh1 = _mm256_add_epi32(sum10_odd, sum20_odd);
            sum_odd_oh1 = _mm256_add_epi32(sum_odd_oh1, sum30_odd);
            sum_odd_oh1 = _mm256_add_epi32(sum_odd_oh1, sum40_odd);
            sum_odd_oh1 = _mm256_add_epi32(sum_odd_oh1, sum5_odd);
            sum_even_oh1 = _mm256_add_epi32(sum10_even, sum20_even);
            sum_even_oh1 = _mm256_add_epi32(sum_even_oh1, sum30_even);
            sum_even_oh1 = _mm256_add_epi32(sum_even_oh1, sum40_even);
            sum_even_oh1 = _mm256_add_epi32(sum_even_oh1, sum5_even);
            __m256i sum_odd_1 =
                    _mm256_unpacklo_epi32(sum_odd_oh1, sum_even_oh1);
            __m256i sum_even_1 =
                    _mm256_unpackhi_epi32(sum_odd_oh1, sum_even_oh1);
            sum_left = _mm256_permute2f128_si256(sum_odd_1, sum_even_1, 32);
            sum_right = _mm256_permute2f128_si256(sum_odd_1, sum_even_1, 49);
            sum_left = _mm256_add_epi32(sum_left, bias_val);
            sum_right = _mm256_add_epi32(sum_right, bias_val);
            if (is_quantized) {
                op({{sum_left, sum_right}}, reinterpret_cast<dt_qint8*>(dst1));
            } else {
                _mm256_storeu_si256((__m256i*)(out_ptr1), sum_left);
                _mm256_storeu_si256((__m256i*)(out_ptr1 + 8), sum_right);
            }

            r0 += 16;
            r1 += 16;
            r2 += 16;
            r3 += 16;
            r4 += 16;
            r5 += 16;
            dst0 += 16;
            dst1 += 16;
            out_ptr0 += 16;
            out_ptr1 += 16;
        }
        r0 += tail_step + IW;
        r1 += tail_step + IW;
        r2 += tail_step + IW;
        r3 += tail_step + IW;
        r4 += tail_step + IW;
        r5 += tail_step + IW;
        dst0 += OW;
        dst1 += OW;
        out_ptr0 += OW;
        out_ptr1 += OW;
    }
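
    //! odd-OH tail: a single output row built from r0..r4.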
    for (; h < OH; h++) {
        size_t w = 0;
        for (; w < width; w++) {
            UNROLL_CALL0(5, load_src0)
            UNROLL_CALL0(5, load_src1)
            UNROLL_CALL0(5, load_src2)
            UNROLL_CALL0(5, load_src3)
            UNROLL_CALL0(5, load_src4)
            UNROLL_CALL0(5, load_src5)
            __m256i sum0_odd, sum0_even, sum1_odd, sum1_even, sum2_odd,
                    sum2_even, sum3_odd, sum3_even, sum4_odd, sum4_even;
            //! cal src0
            __m256i dot1, dot2, dot3;
            dot1 = _mm256_madd_epi16(cvt16_src00, filter_01);
            dot2 = _mm256_madd_epi16(cvt16_src02, filter_23);
            dot3 = _mm256_madd_epi16(cvt16_src04, filter_40);
            sum0_odd = _mm256_add_epi32(dot1, dot2);
            sum0_odd = _mm256_add_epi32(sum0_odd, dot3);
            dot1 = _mm256_madd_epi16(cvt16_src01, filter_01);
            dot2 = _mm256_madd_epi16(cvt16_src03, filter_23);
            dot3 = _mm256_madd_epi16(cvt16_src05, filter_40);
            sum0_even = _mm256_add_epi32(dot1, dot2);
            sum0_even = _mm256_add_epi32(sum0_even, dot3);
            //! cal src1
            dot1 = _mm256_madd_epi16(cvt16_src10, filter_56);
            dot2 = _mm256_madd_epi16(cvt16_src12, filter_78);
            dot3 = _mm256_madd_epi16(cvt16_src14, filter_90);
            sum1_odd = _mm256_add_epi32(dot1, dot2);
            sum1_odd = _mm256_add_epi32(sum1_odd, dot3);
            dot1 = _mm256_madd_epi16(cvt16_src11, filter_56);
            dot2 = _mm256_madd_epi16(cvt16_src13, filter_78);
            dot3 = _mm256_madd_epi16(cvt16_src15, filter_90);
            sum1_even = _mm256_add_epi32(dot1, dot2);
            sum1_even = _mm256_add_epi32(sum1_even, dot3);
            //! cal src2
            dot1 = _mm256_madd_epi16(cvt16_src20, filter_1011);
            dot2 = _mm256_madd_epi16(cvt16_src22, filter_1213);
            dot3 = _mm256_madd_epi16(cvt16_src24, filter_140);
            sum2_odd = _mm256_add_epi32(dot1, dot2);
            sum2_odd = _mm256_add_epi32(sum2_odd, dot3);
            dot1 = _mm256_madd_epi16(cvt16_src21, filter_1011);
            dot2 = _mm256_madd_epi16(cvt16_src23, filter_1213);
            dot3 = _mm256_madd_epi16(cvt16_src25, filter_140);
            sum2_even = _mm256_add_epi32(dot1, dot2);
            sum2_even = _mm256_add_epi32(sum2_even, dot3);
            //! cal src3
            dot1 = _mm256_madd_epi16(cvt16_src30, filter_1516);
            dot2 = _mm256_madd_epi16(cvt16_src32, filter_1718);
            dot3 = _mm256_madd_epi16(cvt16_src34, filter_190);
            sum3_odd = _mm256_add_epi32(dot1, dot2);
            sum3_odd = _mm256_add_epi32(sum3_odd, dot3);
            dot1 = _mm256_madd_epi16(cvt16_src31, filter_1516);
            dot2 = _mm256_madd_epi16(cvt16_src33, filter_1718);
            dot3 = _mm256_madd_epi16(cvt16_src35, filter_190);
            sum3_even = _mm256_add_epi32(dot1, dot2);
            sum3_even = _mm256_add_epi32(sum3_even, dot3);
            //! cal src4
            dot1 = _mm256_madd_epi16(cvt16_src40, filter_2021);
            dot2 = _mm256_madd_epi16(cvt16_src42, filter_2223);
            dot3 = _mm256_madd_epi16(cvt16_src44, filter_240);
            sum4_odd = _mm256_add_epi32(dot1, dot2);
            sum4_odd = _mm256_add_epi32(sum4_odd, dot3);
            dot1 = _mm256_madd_epi16(cvt16_src41, filter_2021);
            dot2 = _mm256_madd_epi16(cvt16_src43, filter_2223);
            dot3 = _mm256_madd_epi16(cvt16_src45, filter_240);
            sum4_even = _mm256_add_epi32(dot1, dot2);
            sum4_even = _mm256_add_epi32(sum4_even, dot3);
            __m256i sum_odd, sum_even;
            sum_odd = _mm256_add_epi32(sum0_odd, sum1_odd);
            sum_odd = _mm256_add_epi32(sum_odd, sum2_odd);
            sum_odd = _mm256_add_epi32(sum_odd, sum3_odd);
            sum_odd = _mm256_add_epi32(sum_odd, sum4_odd);
            sum_even = _mm256_add_epi32(sum0_even, sum1_even);
            sum_even = _mm256_add_epi32(sum_even, sum2_even);
            sum_even = _mm256_add_epi32(sum_even, sum3_even);
            sum_even = _mm256_add_epi32(sum_even, sum4_even);
            __m256i sum_odd_0 = _mm256_unpacklo_epi32(sum_odd, sum_even);
            __m256i sum_even_0 = _mm256_unpackhi_epi32(sum_odd, sum_even);
            __m256i sum_left =
                    _mm256_permute2f128_si256(sum_odd_0, sum_even_0, 32);
            __m256i sum_right =
                    _mm256_permute2f128_si256(sum_odd_0, sum_even_0, 49);
            sum_left = _mm256_add_epi32(sum_left, bias_val);
            sum_right = _mm256_add_epi32(sum_right, bias_val);
            if (is_quantized) {
                op({{sum_left, sum_right}}, reinterpret_cast<dt_qint8*>(dst0));
            } else {
                _mm256_storeu_si256((__m256i*)(out_ptr0), sum_left);
                _mm256_storeu_si256((__m256i*)(out_ptr0 + 8), sum_right);
            }
            r0 += 16;
            r1 += 16;
            r2 += 16;
            r3 += 16;
            r4 += 16;
            dst0 += 16;
            out_ptr0 += 16;
        }
        r0 += tail_step;
        r1 += tail_step;
        r2 += tail_step;
        r3 += tail_step;
        r4 += tail_step;
    }
}
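
//! 7x7: each filter row is split into the tap pairs (k0,k1), (k2,k3), (k4,k5)
//! and a zero-padded tail (k6,0). Eight input rows are loaded per pass and
//! rows 1..6 are shared between the two output rows.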
  730. template <BiasMode bias_mode, bool is_quantized, typename Op>
  731. void avx2_chanwise_direct_stride1_7x7_int8(const int8_t* src,
  732. const int8_t* filter,
  733. const int32_t* bias, int32_t* temp,
  734. int8_t* dst, const size_t IH,
  735. const size_t IW, const size_t OH,
  736. const size_t OW, const Op& op) {
  737. MEGDNN_MARK_USED_VAR(IH);
  738. size_t tail_step = IW - OW;
  739. int8_t* dst0 = dst;
  740. int8_t* dst1 = dst + OW;
  741. int32_t* out_ptr0 = temp;
  742. int32_t* out_ptr1 = temp + OW;
  743. const int8_t* r0 = src;
  744. const int8_t* r1 = src + IW;
  745. const int8_t* r2 = src + 2 * IW;
  746. const int8_t* r3 = src + 3 * IW;
  747. const int8_t* r4 = src + 4 * IW;
  748. const int8_t* r5 = src + 5 * IW;
  749. const int8_t* r6 = src + 6 * IW;
  750. const int8_t* r7 = src + 7 * IW;
  751. uint8_t fill_zero = 0;
  752. UNROLL_CALL0(49, load_filter)
  753. __m128i k_fill = _mm_set1_epi8(fill_zero);
  754. __m128i k01 = _mm_unpacklo_epi8(k_0, k_1);
  755. __m128i k23 = _mm_unpacklo_epi8(k_2, k_3);
  756. __m128i k45 = _mm_unpacklo_epi8(k_4, k_5);
  757. __m128i k60 = _mm_unpacklo_epi8(k_6, k_fill);
  758. __m128i k78 = _mm_unpacklo_epi8(k_7, k_8);
  759. __m128i k910 = _mm_unpacklo_epi8(k_9, k_10);
  760. __m128i k1112 = _mm_unpacklo_epi8(k_11, k_12);
  761. __m128i k130 = _mm_unpacklo_epi8(k_13, k_fill);
  762. __m128i k1415 = _mm_unpacklo_epi8(k_14, k_15);
  763. __m128i k1617 = _mm_unpacklo_epi8(k_16, k_17);
  764. __m128i k1819 = _mm_unpacklo_epi8(k_18, k_19);
  765. __m128i k200 = _mm_unpacklo_epi8(k_20, k_fill);
  766. __m128i k2122 = _mm_unpacklo_epi8(k_21, k_22);
  767. __m128i k2324 = _mm_unpacklo_epi8(k_23, k_24);
  768. __m128i k2526 = _mm_unpacklo_epi8(k_25, k_26);
  769. __m128i k270 = _mm_unpacklo_epi8(k_27, k_fill);
  770. __m128i k2829 = _mm_unpacklo_epi8(k_28, k_29);
  771. __m128i k3031 = _mm_unpacklo_epi8(k_30, k_31);
  772. __m128i k3233 = _mm_unpacklo_epi8(k_32, k_33);
  773. __m128i k340 = _mm_unpacklo_epi8(k_34, k_fill);
  774. __m128i k3536 = _mm_unpacklo_epi8(k_35, k_36);
  775. __m128i k3738 = _mm_unpacklo_epi8(k_37, k_38);
  776. __m128i k3940 = _mm_unpacklo_epi8(k_39, k_40);
  777. __m128i k410 = _mm_unpacklo_epi8(k_41, k_fill);
  778. __m128i k4243 = _mm_unpacklo_epi8(k_42, k_43);
  779. __m128i k4445 = _mm_unpacklo_epi8(k_44, k_45);
  780. __m128i k4647 = _mm_unpacklo_epi8(k_46, k_47);
  781. __m128i k480 = _mm_unpacklo_epi8(k_48, k_fill);
  782. __m256i bias_val;
  783. //! load bias
  784. if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) {
  785. bias_val = _mm256_set1_epi32(*(bias));
  786. } else {
  787. bias_val = _mm256_set1_epi32(0);
  788. }
  789. //! cvt i8 --> i16
  790. __m256i filter_01 = _mm256_cvtepi8_epi16(k01);
  791. __m256i filter_23 = _mm256_cvtepi8_epi16(k23);
  792. __m256i filter_45 = _mm256_cvtepi8_epi16(k45);
  793. __m256i filter_60 = _mm256_cvtepi8_epi16(k60);
  794. __m256i filter_78 = _mm256_cvtepi8_epi16(k78);
  795. __m256i filter_910 = _mm256_cvtepi8_epi16(k910);
  796. __m256i filter_1112 = _mm256_cvtepi8_epi16(k1112);
  797. __m256i filter_130 = _mm256_cvtepi8_epi16(k130);
  798. __m256i filter_1415 = _mm256_cvtepi8_epi16(k1415);
  799. __m256i filter_1617 = _mm256_cvtepi8_epi16(k1617);
  800. __m256i filter_1819 = _mm256_cvtepi8_epi16(k1819);
  801. __m256i filter_200 = _mm256_cvtepi8_epi16(k200);
  802. __m256i filter_2122 = _mm256_cvtepi8_epi16(k2122);
  803. __m256i filter_2324 = _mm256_cvtepi8_epi16(k2324);
  804. __m256i filter_2526 = _mm256_cvtepi8_epi16(k2526);
  805. __m256i filter_270 = _mm256_cvtepi8_epi16(k270);
  806. __m256i filter_2829 = _mm256_cvtepi8_epi16(k2829);
  807. __m256i filter_3031 = _mm256_cvtepi8_epi16(k3031);
  808. __m256i filter_3233 = _mm256_cvtepi8_epi16(k3233);
  809. __m256i filter_340 = _mm256_cvtepi8_epi16(k340);
  810. __m256i filter_3536 = _mm256_cvtepi8_epi16(k3536);
  811. __m256i filter_3738 = _mm256_cvtepi8_epi16(k3738);
  812. __m256i filter_3940 = _mm256_cvtepi8_epi16(k3940);
  813. __m256i filter_410 = _mm256_cvtepi8_epi16(k410);
  814. __m256i filter_4243 = _mm256_cvtepi8_epi16(k4243);
  815. __m256i filter_4445 = _mm256_cvtepi8_epi16(k4445);
  816. __m256i filter_4647 = _mm256_cvtepi8_epi16(k4647);
  817. __m256i filter_480 = _mm256_cvtepi8_epi16(k480);
  818. size_t width = OW >> 4;
  819. size_t h = 0;
  820. for (; h + 1 < OH; h += 2) {
  821. size_t w = 0;
  822. for (; w < width; w++) {
  823. UNROLL_CALL0(8, load_src0)
  824. UNROLL_CALL0(8, load_src1)
  825. UNROLL_CALL0(8, load_src2)
  826. UNROLL_CALL0(8, load_src3)
  827. UNROLL_CALL0(8, load_src4)
  828. UNROLL_CALL0(8, load_src5)
  829. UNROLL_CALL0(8, load_src6)
  830. UNROLL_CALL0(8, load_src7)
  831. __m256i sum0_odd, sum0_even, sum1_odd, sum1_even, sum2_odd,
  832. sum2_even, sum3_odd, sum3_even, sum4_odd, sum4_even,
  833. sum5_odd, sum5_even, sum6_odd, sum6_even;
  834. __m256i sum10_odd, sum10_even, sum20_odd, sum20_even, sum30_odd,
  835. sum30_even, sum40_odd, sum40_even, sum50_odd, sum50_even,
  836. sum60_odd, sum60_even, sum7_odd, sum7_even;
  837. //! cal src0
  838. __m256i dot1, dot2, dot3, dot4;
  839. dot1 = _mm256_madd_epi16(cvt16_src00, filter_01);
  840. dot2 = _mm256_madd_epi16(cvt16_src02, filter_23);
  841. dot3 = _mm256_madd_epi16(cvt16_src04, filter_45);
  842. dot4 = _mm256_madd_epi16(cvt16_src06, filter_60);
  843. sum0_odd = _mm256_add_epi32(dot1, dot2);
  844. sum0_odd = _mm256_add_epi32(sum0_odd, dot3);
  845. sum0_odd = _mm256_add_epi32(sum0_odd, dot4);
  846. dot1 = _mm256_madd_epi16(cvt16_src01, filter_01);
  847. dot2 = _mm256_madd_epi16(cvt16_src03, filter_23);
  848. dot3 = _mm256_madd_epi16(cvt16_src05, filter_45);
  849. dot4 = _mm256_madd_epi16(cvt16_src07, filter_60);
  850. sum0_even = _mm256_add_epi32(dot1, dot2);
  851. sum0_even = _mm256_add_epi32(sum0_even, dot3);
  852. sum0_even = _mm256_add_epi32(sum0_even, dot4);
  853. //! cal src1
  854. dot1 = _mm256_madd_epi16(cvt16_src10, filter_78);
  855. dot2 = _mm256_madd_epi16(cvt16_src12, filter_910);
  856. dot3 = _mm256_madd_epi16(cvt16_src14, filter_1112);
  857. dot4 = _mm256_madd_epi16(cvt16_src16, filter_130);
  858. sum1_odd = _mm256_add_epi32(dot1, dot2);
  859. sum1_odd = _mm256_add_epi32(sum1_odd, dot3);
  860. sum1_odd = _mm256_add_epi32(sum1_odd, dot4);
  861. dot1 = _mm256_madd_epi16(cvt16_src11, filter_78);
  862. dot2 = _mm256_madd_epi16(cvt16_src13, filter_910);
  863. dot3 = _mm256_madd_epi16(cvt16_src15, filter_1112);
  864. dot4 = _mm256_madd_epi16(cvt16_src17, filter_130);
  865. sum1_even = _mm256_add_epi32(dot1, dot2);
  866. sum1_even = _mm256_add_epi32(sum1_even, dot3);
  867. sum1_even = _mm256_add_epi32(sum1_even, dot4);
  868. dot1 = _mm256_madd_epi16(cvt16_src10, filter_01);
  869. dot2 = _mm256_madd_epi16(cvt16_src12, filter_23);
  870. dot3 = _mm256_madd_epi16(cvt16_src14, filter_45);
  871. dot4 = _mm256_madd_epi16(cvt16_src16, filter_60);
  872. sum10_odd = _mm256_add_epi32(dot1, dot2);
  873. sum10_odd = _mm256_add_epi32(sum10_odd, dot3);
  874. sum10_odd = _mm256_add_epi32(sum10_odd, dot4);
  875. dot1 = _mm256_madd_epi16(cvt16_src11, filter_01);
  876. dot2 = _mm256_madd_epi16(cvt16_src13, filter_23);
  877. dot3 = _mm256_madd_epi16(cvt16_src15, filter_45);
  878. dot4 = _mm256_madd_epi16(cvt16_src17, filter_60);
  879. sum10_even = _mm256_add_epi32(dot1, dot2);
  880. sum10_even = _mm256_add_epi32(sum10_even, dot3);
  881. sum10_even = _mm256_add_epi32(sum10_even, dot4);
  882. //! cal src2
  883. dot1 = _mm256_madd_epi16(cvt16_src20, filter_1415);
  884. dot2 = _mm256_madd_epi16(cvt16_src22, filter_1617);
  885. dot3 = _mm256_madd_epi16(cvt16_src24, filter_1819);
  886. dot4 = _mm256_madd_epi16(cvt16_src26, filter_200);
  887. sum2_odd = _mm256_add_epi32(dot1, dot2);
  888. sum2_odd = _mm256_add_epi32(sum2_odd, dot3);
  889. sum2_odd = _mm256_add_epi32(sum2_odd, dot4);
  890. dot1 = _mm256_madd_epi16(cvt16_src21, filter_1415);
  891. dot2 = _mm256_madd_epi16(cvt16_src23, filter_1617);
  892. dot3 = _mm256_madd_epi16(cvt16_src25, filter_1819);
  893. dot4 = _mm256_madd_epi16(cvt16_src27, filter_200);
  894. sum2_even = _mm256_add_epi32(dot1, dot2);
  895. sum2_even = _mm256_add_epi32(sum2_even, dot3);
  896. sum2_even = _mm256_add_epi32(sum2_even, dot4);
  897. dot1 = _mm256_madd_epi16(cvt16_src20, filter_78);
  898. dot2 = _mm256_madd_epi16(cvt16_src22, filter_910);
  899. dot3 = _mm256_madd_epi16(cvt16_src24, filter_1112);
  900. dot4 = _mm256_madd_epi16(cvt16_src26, filter_130);
  901. sum20_odd = _mm256_add_epi32(dot1, dot2);
  902. sum20_odd = _mm256_add_epi32(sum20_odd, dot3);
  903. sum20_odd = _mm256_add_epi32(sum20_odd, dot4);
  904. dot1 = _mm256_madd_epi16(cvt16_src21, filter_78);
  905. dot2 = _mm256_madd_epi16(cvt16_src23, filter_910);
  906. dot3 = _mm256_madd_epi16(cvt16_src25, filter_1112);
  907. dot4 = _mm256_madd_epi16(cvt16_src27, filter_130);
  908. sum20_even = _mm256_add_epi32(dot1, dot2);
  909. sum20_even = _mm256_add_epi32(sum20_even, dot3);
  910. sum20_even = _mm256_add_epi32(sum20_even, dot4);
  911. //! cal src3
  912. dot1 = _mm256_madd_epi16(cvt16_src30, filter_2122);
  913. dot2 = _mm256_madd_epi16(cvt16_src32, filter_2324);
  914. dot3 = _mm256_madd_epi16(cvt16_src34, filter_2526);
  915. dot4 = _mm256_madd_epi16(cvt16_src36, filter_270);
  916. sum3_odd = _mm256_add_epi32(dot1, dot2);
  917. sum3_odd = _mm256_add_epi32(sum3_odd, dot3);
  918. sum3_odd = _mm256_add_epi32(sum3_odd, dot4);
  919. dot1 = _mm256_madd_epi16(cvt16_src31, filter_2122);
  920. dot2 = _mm256_madd_epi16(cvt16_src33, filter_2324);
  921. dot3 = _mm256_madd_epi16(cvt16_src35, filter_2526);
  922. dot4 = _mm256_madd_epi16(cvt16_src37, filter_270);
  923. sum3_even = _mm256_add_epi32(dot1, dot2);
  924. sum3_even = _mm256_add_epi32(sum3_even, dot3);
  925. sum3_even = _mm256_add_epi32(sum3_even, dot4);
  926. dot1 = _mm256_madd_epi16(cvt16_src30, filter_1415);
  927. dot2 = _mm256_madd_epi16(cvt16_src32, filter_1617);
  928. dot3 = _mm256_madd_epi16(cvt16_src34, filter_1819);
  929. dot4 = _mm256_madd_epi16(cvt16_src36, filter_200);
  930. sum30_odd = _mm256_add_epi32(dot1, dot2);
  931. sum30_odd = _mm256_add_epi32(sum30_odd, dot3);
  932. sum30_odd = _mm256_add_epi32(sum30_odd, dot4);
  933. dot1 = _mm256_madd_epi16(cvt16_src31, filter_1415);
  934. dot2 = _mm256_madd_epi16(cvt16_src33, filter_1617);
  935. dot3 = _mm256_madd_epi16(cvt16_src35, filter_1819);
  936. dot4 = _mm256_madd_epi16(cvt16_src37, filter_200);
  937. sum30_even = _mm256_add_epi32(dot1, dot2);
  938. sum30_even = _mm256_add_epi32(sum30_even, dot3);
  939. sum30_even = _mm256_add_epi32(sum30_even, dot4);
  940. //! cal src4
  941. dot1 = _mm256_madd_epi16(cvt16_src40, filter_2829);
  942. dot2 = _mm256_madd_epi16(cvt16_src42, filter_3031);
  943. dot3 = _mm256_madd_epi16(cvt16_src44, filter_3233);
  944. dot4 = _mm256_madd_epi16(cvt16_src46, filter_340);
  945. sum4_odd = _mm256_add_epi32(dot1, dot2);
  946. sum4_odd = _mm256_add_epi32(sum4_odd, dot3);
  947. sum4_odd = _mm256_add_epi32(sum4_odd, dot4);
  948. dot1 = _mm256_madd_epi16(cvt16_src41, filter_2829);
  949. dot2 = _mm256_madd_epi16(cvt16_src43, filter_3031);
  950. dot3 = _mm256_madd_epi16(cvt16_src45, filter_3233);
  951. dot4 = _mm256_madd_epi16(cvt16_src47, filter_340);
  952. sum4_even = _mm256_add_epi32(dot1, dot2);
  953. sum4_even = _mm256_add_epi32(sum4_even, dot3);
  954. sum4_even = _mm256_add_epi32(sum4_even, dot4);
  955. dot1 = _mm256_madd_epi16(cvt16_src40, filter_2122);
  956. dot2 = _mm256_madd_epi16(cvt16_src42, filter_2324);
  957. dot3 = _mm256_madd_epi16(cvt16_src44, filter_2526);
  958. dot4 = _mm256_madd_epi16(cvt16_src46, filter_270);
  959. sum40_odd = _mm256_add_epi32(dot1, dot2);
  960. sum40_odd = _mm256_add_epi32(sum40_odd, dot3);
  961. sum40_odd = _mm256_add_epi32(sum40_odd, dot4);
  962. dot1 = _mm256_madd_epi16(cvt16_src41, filter_2122);
  963. dot2 = _mm256_madd_epi16(cvt16_src43, filter_2324);
  964. dot3 = _mm256_madd_epi16(cvt16_src45, filter_2526);
  965. dot4 = _mm256_madd_epi16(cvt16_src47, filter_270);
  966. sum40_even = _mm256_add_epi32(dot1, dot2);
  967. sum40_even = _mm256_add_epi32(sum40_even, dot3);
  968. sum40_even = _mm256_add_epi32(sum40_even, dot4);
  969. //! cal src5
  970. dot1 = _mm256_madd_epi16(cvt16_src50, filter_3536);
  971. dot2 = _mm256_madd_epi16(cvt16_src52, filter_3738);
  972. dot3 = _mm256_madd_epi16(cvt16_src54, filter_3940);
  973. dot4 = _mm256_madd_epi16(cvt16_src56, filter_410);
  974. sum5_odd = _mm256_add_epi32(dot1, dot2);
  975. sum5_odd = _mm256_add_epi32(sum5_odd, dot3);
  976. sum5_odd = _mm256_add_epi32(sum5_odd, dot4);
  977. dot1 = _mm256_madd_epi16(cvt16_src51, filter_3536);
  978. dot2 = _mm256_madd_epi16(cvt16_src53, filter_3738);
  979. dot3 = _mm256_madd_epi16(cvt16_src55, filter_3940);
  980. dot4 = _mm256_madd_epi16(cvt16_src57, filter_410);
  981. sum5_even = _mm256_add_epi32(dot1, dot2);
  982. sum5_even = _mm256_add_epi32(sum5_even, dot3);
  983. sum5_even = _mm256_add_epi32(sum5_even, dot4);
  984. dot1 = _mm256_madd_epi16(cvt16_src50, filter_2829);
  985. dot2 = _mm256_madd_epi16(cvt16_src52, filter_3031);
  986. dot3 = _mm256_madd_epi16(cvt16_src54, filter_3233);
  987. dot4 = _mm256_madd_epi16(cvt16_src56, filter_340);
  988. sum50_odd = _mm256_add_epi32(dot1, dot2);
  989. sum50_odd = _mm256_add_epi32(sum50_odd, dot3);
  990. sum50_odd = _mm256_add_epi32(sum50_odd, dot4);
  991. dot1 = _mm256_madd_epi16(cvt16_src51, filter_2829);
  992. dot2 = _mm256_madd_epi16(cvt16_src53, filter_3031);
  993. dot3 = _mm256_madd_epi16(cvt16_src55, filter_3233);
  994. dot4 = _mm256_madd_epi16(cvt16_src57, filter_340);
  995. sum50_even = _mm256_add_epi32(dot1, dot2);
  996. sum50_even = _mm256_add_epi32(sum50_even, dot3);
  997. sum50_even = _mm256_add_epi32(sum50_even, dot4);
  998. //! cal src6
  999. dot1 = _mm256_madd_epi16(cvt16_src60, filter_4243);
  1000. dot2 = _mm256_madd_epi16(cvt16_src62, filter_4445);
  1001. dot3 = _mm256_madd_epi16(cvt16_src64, filter_4647);
  1002. dot4 = _mm256_madd_epi16(cvt16_src66, filter_480);
  1003. sum6_odd = _mm256_add_epi32(dot1, dot2);
  1004. sum6_odd = _mm256_add_epi32(sum6_odd, dot3);
  1005. sum6_odd = _mm256_add_epi32(sum6_odd, dot4);
  1006. dot1 = _mm256_madd_epi16(cvt16_src61, filter_4243);
  1007. dot2 = _mm256_madd_epi16(cvt16_src63, filter_4445);
  1008. dot3 = _mm256_madd_epi16(cvt16_src65, filter_4647);
  1009. dot4 = _mm256_madd_epi16(cvt16_src67, filter_480);
  1010. sum6_even = _mm256_add_epi32(dot1, dot2);
  1011. sum6_even = _mm256_add_epi32(sum6_even, dot3);
  1012. sum6_even = _mm256_add_epi32(sum6_even, dot4);
  1013. dot1 = _mm256_madd_epi16(cvt16_src60, filter_3536);
  1014. dot2 = _mm256_madd_epi16(cvt16_src62, filter_3738);
  1015. dot3 = _mm256_madd_epi16(cvt16_src64, filter_3940);
  1016. dot4 = _mm256_madd_epi16(cvt16_src66, filter_410);
  1017. sum60_odd = _mm256_add_epi32(dot1, dot2);
  1018. sum60_odd = _mm256_add_epi32(sum60_odd, dot3);
  1019. sum60_odd = _mm256_add_epi32(sum60_odd, dot4);
  1020. dot1 = _mm256_madd_epi16(cvt16_src61, filter_3536);
  1021. dot2 = _mm256_madd_epi16(cvt16_src63, filter_3738);
  1022. dot3 = _mm256_madd_epi16(cvt16_src65, filter_3940);
  1023. dot4 = _mm256_madd_epi16(cvt16_src67, filter_410);
  1024. sum60_even = _mm256_add_epi32(dot1, dot2);
  1025. sum60_even = _mm256_add_epi32(sum60_even, dot3);
  1026. sum60_even = _mm256_add_epi32(sum60_even, dot4);
  1027. dot1 = _mm256_madd_epi16(cvt16_src70, filter_4243);
  1028. dot2 = _mm256_madd_epi16(cvt16_src72, filter_4445);
  1029. dot3 = _mm256_madd_epi16(cvt16_src74, filter_4647);
  1030. dot4 = _mm256_madd_epi16(cvt16_src76, filter_480);
  1031. sum7_odd = _mm256_add_epi32(dot1, dot2);
  1032. sum7_odd = _mm256_add_epi32(sum7_odd, dot3);
  1033. sum7_odd = _mm256_add_epi32(sum7_odd, dot4);
  1034. dot1 = _mm256_madd_epi16(cvt16_src71, filter_4243);
  1035. dot2 = _mm256_madd_epi16(cvt16_src73, filter_4445);
  1036. dot3 = _mm256_madd_epi16(cvt16_src75, filter_4647);
  1037. dot4 = _mm256_madd_epi16(cvt16_src77, filter_480);
  1038. sum7_even = _mm256_add_epi32(dot1, dot2);
  1039. sum7_even = _mm256_add_epi32(sum7_even, dot3);
  1040. sum7_even = _mm256_add_epi32(sum7_even, dot4);
  1041. __m256i sum_odd, sum_even;
  1042. //! add src0 ~ src6
  1043. sum_odd = _mm256_add_epi32(sum0_odd, sum1_odd);
  1044. sum_odd = _mm256_add_epi32(sum_odd, sum2_odd);
  1045. sum_odd = _mm256_add_epi32(sum_odd, sum3_odd);
  1046. sum_odd = _mm256_add_epi32(sum_odd, sum4_odd);
  1047. sum_odd = _mm256_add_epi32(sum_odd, sum5_odd);
  1048. sum_odd = _mm256_add_epi32(sum_odd, sum6_odd);
  1049. sum_even = _mm256_add_epi32(sum0_even, sum1_even);
  1050. sum_even = _mm256_add_epi32(sum_even, sum2_even);
  1051. sum_even = _mm256_add_epi32(sum_even, sum3_even);
  1052. sum_even = _mm256_add_epi32(sum_even, sum4_even);
  1053. sum_even = _mm256_add_epi32(sum_even, sum5_even);
  1054. sum_even = _mm256_add_epi32(sum_even, sum6_even);
  1055. __m256i sum_odd_0 = _mm256_unpacklo_epi32(sum_odd, sum_even);
  1056. __m256i sum_even_0 = _mm256_unpackhi_epi32(sum_odd, sum_even);
  1057. __m256i sum_left =
  1058. _mm256_permute2f128_si256(sum_odd_0, sum_even_0, 32);
  1059. __m256i sum_right =
  1060. _mm256_permute2f128_si256(sum_odd_0, sum_even_0, 49);
  1061. sum_left = _mm256_add_epi32(sum_left, bias_val);
  1062. sum_right = _mm256_add_epi32(sum_right, bias_val);
  1063. if (is_quantized) {
  1064. op({{sum_left, sum_right}}, reinterpret_cast<dt_qint8*>(dst0));
  1065. } else {
  1066. _mm256_storeu_si256((__m256i*)(out_ptr0), sum_left);
  1067. _mm256_storeu_si256((__m256i*)(out_ptr0 + 8), sum_right);
  1068. }
  1069. __m256i sum_odd_oh1, sum_even_oh1;
  1070. //! add src1 ~ src7
  1071. sum_odd_oh1 = _mm256_add_epi32(sum10_odd, sum20_odd);
  1072. sum_odd_oh1 = _mm256_add_epi32(sum_odd_oh1, sum30_odd);
  1073. sum_odd_oh1 = _mm256_add_epi32(sum_odd_oh1, sum40_odd);
  1074. sum_odd_oh1 = _mm256_add_epi32(sum_odd_oh1, sum50_odd);
  1075. sum_odd_oh1 = _mm256_add_epi32(sum_odd_oh1, sum60_odd);
  1076. sum_odd_oh1 = _mm256_add_epi32(sum_odd_oh1, sum7_odd);
  1077. sum_even_oh1 = _mm256_add_epi32(sum10_even, sum20_even);
  1078. sum_even_oh1 = _mm256_add_epi32(sum_even_oh1, sum30_even);
  1079. sum_even_oh1 = _mm256_add_epi32(sum_even_oh1, sum40_even);
  1080. sum_even_oh1 = _mm256_add_epi32(sum_even_oh1, sum50_even);
  1081. sum_even_oh1 = _mm256_add_epi32(sum_even_oh1, sum60_even);
  1082. sum_even_oh1 = _mm256_add_epi32(sum_even_oh1, sum7_even);
            __m256i sum_odd_1 =
                    _mm256_unpacklo_epi32(sum_odd_oh1, sum_even_oh1);
            __m256i sum_even_1 =
                    _mm256_unpackhi_epi32(sum_odd_oh1, sum_even_oh1);
            sum_left = _mm256_permute2f128_si256(sum_odd_1, sum_even_1, 32);
            sum_right = _mm256_permute2f128_si256(sum_odd_1, sum_even_1, 49);
            sum_left = _mm256_add_epi32(sum_left, bias_val);
            sum_right = _mm256_add_epi32(sum_right, bias_val);
            if (is_quantized) {
                op({{sum_left, sum_right}}, reinterpret_cast<dt_qint8*>(dst1));
            } else {
                _mm256_storeu_si256((__m256i*)(out_ptr1), sum_left);
                _mm256_storeu_si256((__m256i*)(out_ptr1 + 8), sum_right);
            }
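            //! step to the next 16 output columns (stride 1, so the input
            //! pointers also move by 16)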
            r0 += 16;
            r1 += 16;
            r2 += 16;
            r3 += 16;
            r4 += 16;
            r5 += 16;
            r6 += 16;
            r7 += 16;
            dst0 += 16;
            dst1 += 16;
            out_ptr0 += 16;
            out_ptr1 += 16;
        }
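        //! two output rows were produced: the input rows advance by the row
        //! tail plus one extra full input row, and dst0/out_ptr0 skip the
        //! output row already written through dst1/out_ptr1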
        r0 += tail_step + IW;
        r1 += tail_step + IW;
        r2 += tail_step + IW;
        r3 += tail_step + IW;
        r4 += tail_step + IW;
        r5 += tail_step + IW;
        r6 += tail_step + IW;
        r7 += tail_step + IW;
        dst0 += OW;
        dst1 += OW;
        out_ptr0 += OW;
        out_ptr1 += OW;
    }
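    //! leftover single output row (when OH is odd): same computation, but
    //! only src0 ~ src6 contribute and only dst0 is written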
    for (; h < OH; h++) {
        size_t w = 0;
        for (; w < width; w++) {
            UNROLL_CALL0(7, load_src0)
            UNROLL_CALL0(7, load_src1)
            UNROLL_CALL0(7, load_src2)
            UNROLL_CALL0(7, load_src3)
            UNROLL_CALL0(7, load_src4)
            UNROLL_CALL0(7, load_src5)
            UNROLL_CALL0(7, load_src6)
            UNROLL_CALL0(7, load_src7)
            __m256i sum0_odd, sum0_even, sum1_odd, sum1_even, sum2_odd,
                    sum2_even, sum3_odd, sum3_even, sum4_odd, sum4_even,
                    sum5_odd, sum5_even, sum6_odd, sum6_even;
            //! cal src0
            __m256i dot1, dot2, dot3, dot4;
            dot1 = _mm256_madd_epi16(cvt16_src00, filter_01);
            dot2 = _mm256_madd_epi16(cvt16_src02, filter_23);
            dot3 = _mm256_madd_epi16(cvt16_src04, filter_45);
            dot4 = _mm256_madd_epi16(cvt16_src06, filter_60);
            sum0_odd = _mm256_add_epi32(dot1, dot2);
            sum0_odd = _mm256_add_epi32(sum0_odd, dot3);
            sum0_odd = _mm256_add_epi32(sum0_odd, dot4);
            dot1 = _mm256_madd_epi16(cvt16_src01, filter_01);
            dot2 = _mm256_madd_epi16(cvt16_src03, filter_23);
            dot3 = _mm256_madd_epi16(cvt16_src05, filter_45);
            dot4 = _mm256_madd_epi16(cvt16_src07, filter_60);
            sum0_even = _mm256_add_epi32(dot1, dot2);
            sum0_even = _mm256_add_epi32(sum0_even, dot3);
            sum0_even = _mm256_add_epi32(sum0_even, dot4);
            //! cal src1
            dot1 = _mm256_madd_epi16(cvt16_src10, filter_78);
            dot2 = _mm256_madd_epi16(cvt16_src12, filter_910);
            dot3 = _mm256_madd_epi16(cvt16_src14, filter_1112);
            dot4 = _mm256_madd_epi16(cvt16_src16, filter_130);
            sum1_odd = _mm256_add_epi32(dot1, dot2);
            sum1_odd = _mm256_add_epi32(sum1_odd, dot3);
            sum1_odd = _mm256_add_epi32(sum1_odd, dot4);
            dot1 = _mm256_madd_epi16(cvt16_src11, filter_78);
            dot2 = _mm256_madd_epi16(cvt16_src13, filter_910);
            dot3 = _mm256_madd_epi16(cvt16_src15, filter_1112);
            dot4 = _mm256_madd_epi16(cvt16_src17, filter_130);
            sum1_even = _mm256_add_epi32(dot1, dot2);
            sum1_even = _mm256_add_epi32(sum1_even, dot3);
            sum1_even = _mm256_add_epi32(sum1_even, dot4);
            //! cal src2
            dot1 = _mm256_madd_epi16(cvt16_src20, filter_1415);
            dot2 = _mm256_madd_epi16(cvt16_src22, filter_1617);
            dot3 = _mm256_madd_epi16(cvt16_src24, filter_1819);
            dot4 = _mm256_madd_epi16(cvt16_src26, filter_200);
            sum2_odd = _mm256_add_epi32(dot1, dot2);
            sum2_odd = _mm256_add_epi32(sum2_odd, dot3);
            sum2_odd = _mm256_add_epi32(sum2_odd, dot4);
            dot1 = _mm256_madd_epi16(cvt16_src21, filter_1415);
            dot2 = _mm256_madd_epi16(cvt16_src23, filter_1617);
            dot3 = _mm256_madd_epi16(cvt16_src25, filter_1819);
            dot4 = _mm256_madd_epi16(cvt16_src27, filter_200);
            sum2_even = _mm256_add_epi32(dot1, dot2);
            sum2_even = _mm256_add_epi32(sum2_even, dot3);
            sum2_even = _mm256_add_epi32(sum2_even, dot4);
            //! cal src3
            dot1 = _mm256_madd_epi16(cvt16_src30, filter_2122);
            dot2 = _mm256_madd_epi16(cvt16_src32, filter_2324);
            dot3 = _mm256_madd_epi16(cvt16_src34, filter_2526);
            dot4 = _mm256_madd_epi16(cvt16_src36, filter_270);
            sum3_odd = _mm256_add_epi32(dot1, dot2);
            sum3_odd = _mm256_add_epi32(sum3_odd, dot3);
            sum3_odd = _mm256_add_epi32(sum3_odd, dot4);
            dot1 = _mm256_madd_epi16(cvt16_src31, filter_2122);
            dot2 = _mm256_madd_epi16(cvt16_src33, filter_2324);
            dot3 = _mm256_madd_epi16(cvt16_src35, filter_2526);
            dot4 = _mm256_madd_epi16(cvt16_src37, filter_270);
            sum3_even = _mm256_add_epi32(dot1, dot2);
            sum3_even = _mm256_add_epi32(sum3_even, dot3);
            sum3_even = _mm256_add_epi32(sum3_even, dot4);
            //! cal src4
            dot1 = _mm256_madd_epi16(cvt16_src40, filter_2829);
            dot2 = _mm256_madd_epi16(cvt16_src42, filter_3031);
            dot3 = _mm256_madd_epi16(cvt16_src44, filter_3233);
            dot4 = _mm256_madd_epi16(cvt16_src46, filter_340);
            sum4_odd = _mm256_add_epi32(dot1, dot2);
            sum4_odd = _mm256_add_epi32(sum4_odd, dot3);
            sum4_odd = _mm256_add_epi32(sum4_odd, dot4);
            dot1 = _mm256_madd_epi16(cvt16_src41, filter_2829);
            dot2 = _mm256_madd_epi16(cvt16_src43, filter_3031);
            dot3 = _mm256_madd_epi16(cvt16_src45, filter_3233);
            dot4 = _mm256_madd_epi16(cvt16_src47, filter_340);
            sum4_even = _mm256_add_epi32(dot1, dot2);
            sum4_even = _mm256_add_epi32(sum4_even, dot3);
            sum4_even = _mm256_add_epi32(sum4_even, dot4);
            //! cal src5
            dot1 = _mm256_madd_epi16(cvt16_src50, filter_3536);
            dot2 = _mm256_madd_epi16(cvt16_src52, filter_3738);
            dot3 = _mm256_madd_epi16(cvt16_src54, filter_3940);
            dot4 = _mm256_madd_epi16(cvt16_src56, filter_410);
            sum5_odd = _mm256_add_epi32(dot1, dot2);
            sum5_odd = _mm256_add_epi32(sum5_odd, dot3);
            sum5_odd = _mm256_add_epi32(sum5_odd, dot4);
            dot1 = _mm256_madd_epi16(cvt16_src51, filter_3536);
            dot2 = _mm256_madd_epi16(cvt16_src53, filter_3738);
            dot3 = _mm256_madd_epi16(cvt16_src55, filter_3940);
            dot4 = _mm256_madd_epi16(cvt16_src57, filter_410);
            sum5_even = _mm256_add_epi32(dot1, dot2);
            sum5_even = _mm256_add_epi32(sum5_even, dot3);
            sum5_even = _mm256_add_epi32(sum5_even, dot4);
            //! cal src6
            dot1 = _mm256_madd_epi16(cvt16_src60, filter_4243);
            dot2 = _mm256_madd_epi16(cvt16_src62, filter_4445);
            dot3 = _mm256_madd_epi16(cvt16_src64, filter_4647);
            dot4 = _mm256_madd_epi16(cvt16_src66, filter_480);
            sum6_odd = _mm256_add_epi32(dot1, dot2);
            sum6_odd = _mm256_add_epi32(sum6_odd, dot3);
            sum6_odd = _mm256_add_epi32(sum6_odd, dot4);
            dot1 = _mm256_madd_epi16(cvt16_src61, filter_4243);
            dot2 = _mm256_madd_epi16(cvt16_src63, filter_4445);
            dot3 = _mm256_madd_epi16(cvt16_src65, filter_4647);
            dot4 = _mm256_madd_epi16(cvt16_src67, filter_480);
            sum6_even = _mm256_add_epi32(dot1, dot2);
            sum6_even = _mm256_add_epi32(sum6_even, dot3);
            sum6_even = _mm256_add_epi32(sum6_even, dot4);
            __m256i sum_odd, sum_even;
            //! add src0 ~ src6
            sum_odd = _mm256_add_epi32(sum0_odd, sum1_odd);
            sum_odd = _mm256_add_epi32(sum_odd, sum2_odd);
            sum_odd = _mm256_add_epi32(sum_odd, sum3_odd);
            sum_odd = _mm256_add_epi32(sum_odd, sum4_odd);
            sum_odd = _mm256_add_epi32(sum_odd, sum5_odd);
            sum_odd = _mm256_add_epi32(sum_odd, sum6_odd);
            sum_even = _mm256_add_epi32(sum0_even, sum1_even);
            sum_even = _mm256_add_epi32(sum_even, sum2_even);
            sum_even = _mm256_add_epi32(sum_even, sum3_even);
            sum_even = _mm256_add_epi32(sum_even, sum4_even);
            sum_even = _mm256_add_epi32(sum_even, sum5_even);
            sum_even = _mm256_add_epi32(sum_even, sum6_even);
            __m256i sum_odd_0 = _mm256_unpacklo_epi32(sum_odd, sum_even);
            __m256i sum_even_0 = _mm256_unpackhi_epi32(sum_odd, sum_even);
            __m256i sum_left =
                    _mm256_permute2f128_si256(sum_odd_0, sum_even_0, 32);
            __m256i sum_right =
                    _mm256_permute2f128_si256(sum_odd_0, sum_even_0, 49);
            sum_left = _mm256_add_epi32(sum_left, bias_val);
            sum_right = _mm256_add_epi32(sum_right, bias_val);
            if (is_quantized) {
                op({{sum_left, sum_right}}, reinterpret_cast<dt_qint8*>(dst0));
            } else {
                _mm256_storeu_si256((__m256i*)(out_ptr0), sum_left);
                _mm256_storeu_si256((__m256i*)(out_ptr0 + 8), sum_right);
            }
            r0 += 16;
            r1 += 16;
            r2 += 16;
            r3 += 16;
            r4 += 16;
            r5 += 16;
            r6 += 16;
            dst0 += 16;
            out_ptr0 += 16;
        }
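        //! only one output row per iteration here, so the input rows advance
        //! by the row tail alone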
        r0 += tail_step;
        r1 += tail_step;
        r2 += tail_step;
        r3 += tail_step;
        r4 += tail_step;
        r5 += tail_step;
        r6 += tail_step;
    }
}
#undef load_filter
#undef load_src0
#undef load_src1
#undef load_src2
#undef load_src3
#undef load_src4
#undef load_src5
#undef load_src6
#undef load_src7
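//! explicit instantiations of the stride-1 direct kernels: every post op
//! (TypeCvt / Relu / HSwish), bias mode, quantized flag and filter size
//! (2, 3, 5, 7)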
#define INSTANTIATION(stride, i, bias, is_quantized, Op) \
    template void avx2_chanwise_direct_##stride##_##i##x##i##_int8< \
            bias, is_quantized, Op>(const int8_t*, const int8_t*, \
                    const int32_t*, int32_t*, int8_t*, \
                    const size_t, const size_t, const size_t, \
                    const size_t, const Op&);
#define FOR_OP(stride, i, is_quantized, bias) \
    INSTANTIATION(stride, i, bias, is_quantized, \
            TypeCvtOp<SIMDType::AVX2 MEGDNN_COMMA dt_qint32 MEGDNN_COMMA \
                    dt_qint8>) \
    INSTANTIATION(stride, i, bias, is_quantized, \
            ReluOp<SIMDType::AVX2 MEGDNN_COMMA dt_qint32 MEGDNN_COMMA \
                    dt_qint8>) \
    INSTANTIATION(stride, i, bias, is_quantized, \
            HSwishOp<SIMDType::AVX2 MEGDNN_COMMA dt_qint32 MEGDNN_COMMA \
                    dt_qint8>)
#define FOR_BIAS(stride, i, is_quantized) \
    FOR_OP(stride, i, is_quantized, BiasMode::NO_BIAS) \
    FOR_OP(stride, i, is_quantized, BiasMode::BROADCAST_CHANNEL_BIAS)
#define FOR_QUANTIZED(stride, i) \
    FOR_BIAS(stride, i, true) \
    FOR_BIAS(stride, i, false)
#define FOR_FILTER(stride) \
    FOR_QUANTIZED(stride, 2) \
    FOR_QUANTIZED(stride, 3) \
    FOR_QUANTIZED(stride, 5) \
    FOR_QUANTIZED(stride, 7)
#define FOR_STRIDE FOR_FILTER(stride1)
FOR_STRIDE
#undef FOR_STRIDE
#undef FOR_FILTER
#undef FOR_QUANTIZED
#undef FOR_BIAS
#undef FOR_OP
#undef INSTANTIATION
} // namespace avx2_chanwise_stride1
} // namespace x86
} // namespace megdnn
// vim: syntax=cpp.doxygen
