
kernel_sse_4x8x2.h

/**
 * \file dnn/src/x86/matrix_mul/int8/kernel_sse_4x8x2.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied.
 */
#include <immintrin.h>
#ifdef WIN32
#include <avx2intrin.h>
#include <avxintrin.h>
#include <fmaintrin.h>
#include <smmintrin.h>
#endif
#include <cmath>
#include <cstdint>
#include <type_traits>

#include "src/common/utils.h"
#include "src/x86/matrix_mul/common/common.h"

namespace megdnn {
namespace x86 {
namespace matmul_sse_4x8x2 {
template <typename CType>
MEGDNN_ATTRIBUTE_TARGET("sse4.1")
void store_overflow(void* ptr, __m128i a);

// narrow the four int32 lanes to int16 and store 4 values
template <>
void store_overflow<int16_t>(void* ptr, __m128i a) {
    a = _mm_shufflelo_epi16(a, 0x08);
    a = _mm_shufflehi_epi16(a, 0x08);
    a = _mm_shuffle_epi32(a, 0x08);
    _mm_storel_epi64((__m128i*)ptr, a);
}
template <>
void store_overflow<int32_t>(void* ptr, __m128i a) {
    _mm_storeu_si128((__m128i*)(ptr), a);
}

template <typename CType>
MEGDNN_ATTRIBUTE_TARGET("sse4.1")
void store_overflow(void* ptr, __m128i a, int remain);

// masked variants: store only the first `remain` elements of the row
template <>
void store_overflow<int16_t>(void* ptr, __m128i a, int remain) {
    __m128i mask = _mm_continue_mask(remain * sizeof(int16_t));
    a = _mm_shufflelo_epi16(a, 0x08);
    a = _mm_shufflehi_epi16(a, 0x08);
    a = _mm_shuffle_epi32(a, 0x08);
    _mm_maskmoveu_si128(a, mask, reinterpret_cast<char*>(ptr));
}
template <>
void store_overflow<int32_t>(void* ptr, __m128i a, int remain) {
    __m128i mask = _mm_continue_mask(remain * sizeof(int32_t));
    _mm_maskmoveu_si128(a, mask, reinterpret_cast<char*>(ptr));
}
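
// Reading aid for store_overflow<int16_t> above (not part of the kernel):
// the three shuffles with control 0x08 gather the low 16 bits of each of
// the four int32 lanes into the bottom 64 bits, so the store narrows
// int32 -> int16 by truncation, roughly:
//
//     int32_t lanes[4];
//     std::memcpy(lanes, &a, sizeof(lanes));
//     for (int i = 0; i < 4; ++i)
//         reinterpret_cast<int16_t*>(ptr)[i] = static_cast<int16_t>(lanes[i]);
//
// The `int remain` overloads do the same but emit only the first `remain`
// elements via _mm_maskmoveu_si128.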

template <typename CType>
MEGDNN_ATTRIBUTE_TARGET("sse4.1")
static inline void kern_gemm_s8s8s32_sse_4x8x2(const int16_t* pack_a_ptr,
                                               const int8_t* pack_b_ptr,
                                               CType* c_ptr, const int ldc,
                                               const int k) {
    constexpr int k_step = 2;
    __m128i a_vec[2];
    __m128i b_vec[2];
    __m128i c_vec[4 * 2];
    __m128i c_temp[4];

    // the first k-step is peeled so the eight accumulators can be
    // initialized directly from the first products instead of loading C
    b_vec[0] = _mm_cvtepi8_epi16_from_ptr(pack_b_ptr);
    b_vec[1] = _mm_cvtepi8_epi16_from_ptr(pack_b_ptr + 8);
    a_vec[0] = _mm_set1_epi32(*(int32_t*)(pack_a_ptr));
    a_vec[1] = _mm_set1_epi32(*(int32_t*)(pack_a_ptr + 2));
    c_vec[0] = _mm_madd_epi16(a_vec[0], b_vec[0]);
    c_vec[1] = _mm_madd_epi16(a_vec[0], b_vec[1]);
    c_vec[2] = _mm_madd_epi16(a_vec[1], b_vec[0]);
    c_vec[3] = _mm_madd_epi16(a_vec[1], b_vec[1]);
    a_vec[0] = _mm_set1_epi32(*(int32_t*)(pack_a_ptr + 4));
    a_vec[1] = _mm_set1_epi32(*(int32_t*)(pack_a_ptr + 6));
    c_vec[4] = _mm_madd_epi16(a_vec[0], b_vec[0]);
    c_vec[5] = _mm_madd_epi16(a_vec[0], b_vec[1]);
    c_vec[6] = _mm_madd_epi16(a_vec[1], b_vec[0]);
    c_vec[7] = _mm_madd_epi16(a_vec[1], b_vec[1]);
    pack_a_ptr += 8;
    pack_b_ptr += 16;

    for (int iter_k = 2; iter_k < k; iter_k += k_step) {
        b_vec[0] = _mm_cvtepi8_epi16_from_ptr(pack_b_ptr);
        b_vec[1] = _mm_cvtepi8_epi16_from_ptr(pack_b_ptr + 8);
        a_vec[0] = _mm_set1_epi32(*(int32_t*)(pack_a_ptr));
        a_vec[1] = _mm_set1_epi32(*(int32_t*)(pack_a_ptr + 2));
        c_temp[0] = _mm_madd_epi16(a_vec[0], b_vec[0]);
        c_temp[1] = _mm_madd_epi16(a_vec[0], b_vec[1]);
        c_vec[0] = _mm_add_epi32(c_vec[0], c_temp[0]);
        c_vec[1] = _mm_add_epi32(c_vec[1], c_temp[1]);
        c_temp[2] = _mm_madd_epi16(a_vec[1], b_vec[0]);
        c_temp[3] = _mm_madd_epi16(a_vec[1], b_vec[1]);
        c_vec[2] = _mm_add_epi32(c_vec[2], c_temp[2]);
        c_vec[3] = _mm_add_epi32(c_vec[3], c_temp[3]);
        a_vec[0] = _mm_set1_epi32(*(int32_t*)(pack_a_ptr + 4));
        a_vec[1] = _mm_set1_epi32(*(int32_t*)(pack_a_ptr + 6));
        c_temp[0] = _mm_madd_epi16(a_vec[0], b_vec[0]);
        c_temp[1] = _mm_madd_epi16(a_vec[0], b_vec[1]);
        c_vec[4] = _mm_add_epi32(c_vec[4], c_temp[0]);
        c_vec[5] = _mm_add_epi32(c_vec[5], c_temp[1]);
        c_temp[2] = _mm_madd_epi16(a_vec[1], b_vec[0]);
        c_temp[3] = _mm_madd_epi16(a_vec[1], b_vec[1]);
        c_vec[6] = _mm_add_epi32(c_vec[6], c_temp[2]);
        c_vec[7] = _mm_add_epi32(c_vec[7], c_temp[3]);
        pack_a_ptr += 8;
        pack_b_ptr += 16;
    }

    // write back the 4x8 tile: two 4-wide halves per row
    store_overflow<CType>(c_ptr, c_vec[0]);
    store_overflow<CType>(c_ptr + 4, c_vec[1]);
    store_overflow<CType>(c_ptr + ldc, c_vec[2]);
    store_overflow<CType>(c_ptr + ldc + 4, c_vec[3]);
    store_overflow<CType>(c_ptr + 2 * ldc, c_vec[4]);
    store_overflow<CType>(c_ptr + 2 * ldc + 4, c_vec[5]);
    store_overflow<CType>(c_ptr + 3 * ldc, c_vec[6]);
    store_overflow<CType>(c_ptr + 3 * ldc + 4, c_vec[7]);
}
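
// A scalar reference sketch of what the micro-kernel above computes,
// assuming the packed layouts produced by the pack_* helpers below; the
// function name is illustrative and not part of the shipped kernel. Each
// k-step multiplies an int16 pair from A with a sign-extended int8 pair
// from B and accumulates into int32, mirroring _mm_madd_epi16 per lane.
static inline void reference_gemm_4x8x2(const int16_t* pack_a,
                                        const int8_t* pack_b, int32_t* c,
                                        int ldc, int k) {
    for (int m = 0; m < 4; ++m) {
        for (int n = 0; n < 8; ++n) {
            c[m * ldc + n] = 0;
        }
    }
    for (int kk = 0; kk < k; kk += 2) {
        for (int m = 0; m < 4; ++m) {
            for (int n = 0; n < 8; ++n) {
                // A holds (k, k+1) pairs per row; B holds (k, k+1) pairs
                // per column.
                c[m * ldc + n] += pack_a[m * 2 + 0] * pack_b[n * 2 + 0] +
                                  pack_a[m * 2 + 1] * pack_b[n * 2 + 1];
            }
        }
        pack_a += 8;
        pack_b += 16;
    }
}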

template <typename CType>
MEGDNN_ATTRIBUTE_TARGET("sse4.1")
static inline void kern_gemm_s8s8s32_sse_4x8x2_remain_m(
        const int16_t* pack_a_ptr, const int8_t* pack_b_ptr, CType* c_ptr,
        const int ldc, const int k, const int remain_m) {
    constexpr int k_step = 2;
    __m128i a_vec[2];
    __m128i b_vec[2];
    __m128i c_vec[4 * 2];
    __m128i c_temp[4];

    // peeled first k-step initializes the accumulators
    b_vec[0] = _mm_cvtepi8_epi16_from_ptr(pack_b_ptr);
    b_vec[1] = _mm_cvtepi8_epi16_from_ptr(pack_b_ptr + 8);
    a_vec[0] = _mm_set1_epi32(*(int32_t*)(pack_a_ptr));
    a_vec[1] = _mm_set1_epi32(*(int32_t*)(pack_a_ptr + 2));
    c_vec[0] = _mm_madd_epi16(a_vec[0], b_vec[0]);
    c_vec[1] = _mm_madd_epi16(a_vec[0], b_vec[1]);
    c_vec[2] = _mm_madd_epi16(a_vec[1], b_vec[0]);
    c_vec[3] = _mm_madd_epi16(a_vec[1], b_vec[1]);
    a_vec[0] = _mm_set1_epi32(*(int32_t*)(pack_a_ptr + 4));
    a_vec[1] = _mm_set1_epi32(*(int32_t*)(pack_a_ptr + 6));
    c_vec[4] = _mm_madd_epi16(a_vec[0], b_vec[0]);
    c_vec[5] = _mm_madd_epi16(a_vec[0], b_vec[1]);
    c_vec[6] = _mm_madd_epi16(a_vec[1], b_vec[0]);
    c_vec[7] = _mm_madd_epi16(a_vec[1], b_vec[1]);
    pack_a_ptr += 8;
    pack_b_ptr += 16;

    for (int iter_k = 2; iter_k < k; iter_k += k_step) {
        b_vec[0] = _mm_cvtepi8_epi16_from_ptr(pack_b_ptr);
        b_vec[1] = _mm_cvtepi8_epi16_from_ptr(pack_b_ptr + 8);
        a_vec[0] = _mm_set1_epi32(*(int32_t*)(pack_a_ptr));
        a_vec[1] = _mm_set1_epi32(*(int32_t*)(pack_a_ptr + 2));
        c_temp[0] = _mm_madd_epi16(a_vec[0], b_vec[0]);
        c_temp[1] = _mm_madd_epi16(a_vec[0], b_vec[1]);
        c_vec[0] = _mm_add_epi32(c_vec[0], c_temp[0]);
        c_vec[1] = _mm_add_epi32(c_vec[1], c_temp[1]);
        c_temp[2] = _mm_madd_epi16(a_vec[1], b_vec[0]);
        c_temp[3] = _mm_madd_epi16(a_vec[1], b_vec[1]);
        c_vec[2] = _mm_add_epi32(c_vec[2], c_temp[2]);
        c_vec[3] = _mm_add_epi32(c_vec[3], c_temp[3]);
        a_vec[0] = _mm_set1_epi32(*(int32_t*)(pack_a_ptr + 4));
        a_vec[1] = _mm_set1_epi32(*(int32_t*)(pack_a_ptr + 6));
        c_temp[0] = _mm_madd_epi16(a_vec[0], b_vec[0]);
        c_temp[1] = _mm_madd_epi16(a_vec[0], b_vec[1]);
        c_vec[4] = _mm_add_epi32(c_vec[4], c_temp[0]);
        c_vec[5] = _mm_add_epi32(c_vec[5], c_temp[1]);
        c_temp[2] = _mm_madd_epi16(a_vec[1], b_vec[0]);
        c_temp[3] = _mm_madd_epi16(a_vec[1], b_vec[1]);
        c_vec[6] = _mm_add_epi32(c_vec[6], c_temp[2]);
        c_vec[7] = _mm_add_epi32(c_vec[7], c_temp[3]);
        pack_a_ptr += 8;
        pack_b_ptr += 16;
    }

    // row 0 always exists; rows 1-3 depend on remain_m
    store_overflow<CType>(c_ptr, c_vec[0]);
    store_overflow<CType>(c_ptr + 4, c_vec[1]);
    switch (remain_m) {
        case 2:
            store_overflow<CType>(c_ptr + ldc, c_vec[2]);
            store_overflow<CType>(c_ptr + ldc + 4, c_vec[3]);
            break;
        case 3:
            store_overflow<CType>(c_ptr + ldc, c_vec[2]);
            store_overflow<CType>(c_ptr + ldc + 4, c_vec[3]);
            store_overflow<CType>(c_ptr + 2 * ldc, c_vec[4]);
            store_overflow<CType>(c_ptr + 2 * ldc + 4, c_vec[5]);
            break;
        case 4:
            store_overflow<CType>(c_ptr + ldc, c_vec[2]);
            store_overflow<CType>(c_ptr + ldc + 4, c_vec[3]);
            store_overflow<CType>(c_ptr + 2 * ldc, c_vec[4]);
            store_overflow<CType>(c_ptr + 2 * ldc + 4, c_vec[5]);
            store_overflow<CType>(c_ptr + 3 * ldc, c_vec[6]);
            store_overflow<CType>(c_ptr + 3 * ldc + 4, c_vec[7]);
            break;
        default:
            break;
    }
}

template <typename CType>
MEGDNN_ATTRIBUTE_TARGET("sse4.1")
static inline void kern_gemm_s8s8s32_sse_4x8x2_remain_n(
        const int16_t* pack_a_ptr, const int8_t* pack_b_ptr, CType* c_ptr,
        const int ldc, const int k, int remain_n) {
    constexpr int k_step = 2;
    __m128i a_vec[2];
    __m128i b_vec[2];
    __m128i c_vec[4 * 2];
    __m128i c_temp[4];

    // peeled first k-step initializes the accumulators
    b_vec[0] = _mm_cvtepi8_epi16_from_ptr(pack_b_ptr);
    b_vec[1] = _mm_cvtepi8_epi16_from_ptr(pack_b_ptr + 8);
    a_vec[0] = _mm_set1_epi32(*(int32_t*)(pack_a_ptr));
    a_vec[1] = _mm_set1_epi32(*(int32_t*)(pack_a_ptr + 2));
    c_vec[0] = _mm_madd_epi16(a_vec[0], b_vec[0]);
    c_vec[1] = _mm_madd_epi16(a_vec[0], b_vec[1]);
    c_vec[2] = _mm_madd_epi16(a_vec[1], b_vec[0]);
    c_vec[3] = _mm_madd_epi16(a_vec[1], b_vec[1]);
    a_vec[0] = _mm_set1_epi32(*(int32_t*)(pack_a_ptr + 4));
    a_vec[1] = _mm_set1_epi32(*(int32_t*)(pack_a_ptr + 6));
    c_vec[4] = _mm_madd_epi16(a_vec[0], b_vec[0]);
    c_vec[5] = _mm_madd_epi16(a_vec[0], b_vec[1]);
    c_vec[6] = _mm_madd_epi16(a_vec[1], b_vec[0]);
    c_vec[7] = _mm_madd_epi16(a_vec[1], b_vec[1]);
    pack_a_ptr += 8;
    pack_b_ptr += 16;

    for (int iter_k = 2; iter_k < k; iter_k += k_step) {
        b_vec[0] = _mm_cvtepi8_epi16_from_ptr(pack_b_ptr);
        b_vec[1] = _mm_cvtepi8_epi16_from_ptr(pack_b_ptr + 8);
        a_vec[0] = _mm_set1_epi32(*(int32_t*)(pack_a_ptr));
        a_vec[1] = _mm_set1_epi32(*(int32_t*)(pack_a_ptr + 2));
        c_temp[0] = _mm_madd_epi16(a_vec[0], b_vec[0]);
        c_temp[1] = _mm_madd_epi16(a_vec[0], b_vec[1]);
        c_vec[0] = _mm_add_epi32(c_vec[0], c_temp[0]);
        c_vec[1] = _mm_add_epi32(c_vec[1], c_temp[1]);
        c_temp[2] = _mm_madd_epi16(a_vec[1], b_vec[0]);
        c_temp[3] = _mm_madd_epi16(a_vec[1], b_vec[1]);
        c_vec[2] = _mm_add_epi32(c_vec[2], c_temp[2]);
        c_vec[3] = _mm_add_epi32(c_vec[3], c_temp[3]);
        a_vec[0] = _mm_set1_epi32(*(int32_t*)(pack_a_ptr + 4));
        a_vec[1] = _mm_set1_epi32(*(int32_t*)(pack_a_ptr + 6));
        c_temp[0] = _mm_madd_epi16(a_vec[0], b_vec[0]);
        c_temp[1] = _mm_madd_epi16(a_vec[0], b_vec[1]);
        c_vec[4] = _mm_add_epi32(c_vec[4], c_temp[0]);
        c_vec[5] = _mm_add_epi32(c_vec[5], c_temp[1]);
        c_temp[2] = _mm_madd_epi16(a_vec[1], b_vec[0]);
        c_temp[3] = _mm_madd_epi16(a_vec[1], b_vec[1]);
        c_vec[6] = _mm_add_epi32(c_vec[6], c_temp[2]);
        c_vec[7] = _mm_add_epi32(c_vec[7], c_temp[3]);
        pack_a_ptr += 8;
        pack_b_ptr += 16;
    }

    // store the full left half (columns 0-3) of each row, then shift the
    // right-half accumulators down so the masked store handles the tail
    if (remain_n >= 4) {
        store_overflow<CType>(c_ptr, c_vec[0]);
        store_overflow<CType>(c_ptr + ldc, c_vec[2]);
        store_overflow<CType>(c_ptr + 2 * ldc, c_vec[4]);
        store_overflow<CType>(c_ptr + 3 * ldc, c_vec[6]);
        c_ptr += 4;
        remain_n -= 4;
        c_vec[0] = c_vec[1];
        c_vec[2] = c_vec[3];
        c_vec[4] = c_vec[5];
        c_vec[6] = c_vec[7];
    }
    store_overflow<CType>(c_ptr, c_vec[0], remain_n);
    store_overflow<CType>(c_ptr + ldc, c_vec[2], remain_n);
    store_overflow<CType>(c_ptr + 2 * ldc, c_vec[4], remain_n);
    store_overflow<CType>(c_ptr + 3 * ldc, c_vec[6], remain_n);
}

template <typename CType>
MEGDNN_ATTRIBUTE_TARGET("sse4.1")
static inline void kern_gemm_s8s8s32_sse_4x8x2_remain_m_n(
        const int16_t* pack_a_ptr, const int8_t* pack_b_ptr, CType* c_ptr,
        const int ldc, const int k, int remain_m, int remain_n) {
    constexpr int k_step = 2;
    __m128i a_vec[2];
    __m128i b_vec[2];
    __m128i c_vec[4 * 2];
    __m128i c_temp[4];

    // peeled first k-step initializes the accumulators
    b_vec[0] = _mm_cvtepi8_epi16_from_ptr(pack_b_ptr);
    b_vec[1] = _mm_cvtepi8_epi16_from_ptr(pack_b_ptr + 8);
    a_vec[0] = _mm_set1_epi32(*(int32_t*)(pack_a_ptr));
    a_vec[1] = _mm_set1_epi32(*(int32_t*)(pack_a_ptr + 2));
    c_vec[0] = _mm_madd_epi16(a_vec[0], b_vec[0]);
    c_vec[1] = _mm_madd_epi16(a_vec[0], b_vec[1]);
    c_vec[2] = _mm_madd_epi16(a_vec[1], b_vec[0]);
    c_vec[3] = _mm_madd_epi16(a_vec[1], b_vec[1]);
    a_vec[0] = _mm_set1_epi32(*(int32_t*)(pack_a_ptr + 4));
    a_vec[1] = _mm_set1_epi32(*(int32_t*)(pack_a_ptr + 6));
    c_vec[4] = _mm_madd_epi16(a_vec[0], b_vec[0]);
    c_vec[5] = _mm_madd_epi16(a_vec[0], b_vec[1]);
    c_vec[6] = _mm_madd_epi16(a_vec[1], b_vec[0]);
    c_vec[7] = _mm_madd_epi16(a_vec[1], b_vec[1]);
    pack_a_ptr += 8;
    pack_b_ptr += 16;

    for (int iter_k = 2; iter_k < k; iter_k += k_step) {
        b_vec[0] = _mm_cvtepi8_epi16_from_ptr(pack_b_ptr);
        b_vec[1] = _mm_cvtepi8_epi16_from_ptr(pack_b_ptr + 8);
        a_vec[0] = _mm_set1_epi32(*(int32_t*)(pack_a_ptr));
        a_vec[1] = _mm_set1_epi32(*(int32_t*)(pack_a_ptr + 2));
        c_temp[0] = _mm_madd_epi16(a_vec[0], b_vec[0]);
        c_temp[1] = _mm_madd_epi16(a_vec[0], b_vec[1]);
        c_vec[0] = _mm_add_epi32(c_vec[0], c_temp[0]);
        c_vec[1] = _mm_add_epi32(c_vec[1], c_temp[1]);
        c_temp[2] = _mm_madd_epi16(a_vec[1], b_vec[0]);
        c_temp[3] = _mm_madd_epi16(a_vec[1], b_vec[1]);
        c_vec[2] = _mm_add_epi32(c_vec[2], c_temp[2]);
        c_vec[3] = _mm_add_epi32(c_vec[3], c_temp[3]);
        a_vec[0] = _mm_set1_epi32(*(int32_t*)(pack_a_ptr + 4));
        a_vec[1] = _mm_set1_epi32(*(int32_t*)(pack_a_ptr + 6));
        c_temp[0] = _mm_madd_epi16(a_vec[0], b_vec[0]);
        c_temp[1] = _mm_madd_epi16(a_vec[0], b_vec[1]);
        c_vec[4] = _mm_add_epi32(c_vec[4], c_temp[0]);
        c_vec[5] = _mm_add_epi32(c_vec[5], c_temp[1]);
        c_temp[2] = _mm_madd_epi16(a_vec[1], b_vec[0]);
        c_temp[3] = _mm_madd_epi16(a_vec[1], b_vec[1]);
        c_vec[6] = _mm_add_epi32(c_vec[6], c_temp[2]);
        c_vec[7] = _mm_add_epi32(c_vec[7], c_temp[3]);
        pack_a_ptr += 8;
        pack_b_ptr += 16;
    }

    // index_array[m] picks the left-half accumulator of row m; after the
    // full half is stored, the right-half accumulator is shifted into its
    // slot for the masked tail store
    int index_array[4]{0, 2, 4, 6};
    if (remain_n >= 4) {
        for (int m = 0; m < remain_m; ++m) {
            store_overflow<CType>(c_ptr + m * ldc, c_vec[index_array[m]]);
        }
        c_ptr += 4;
        remain_n -= 4;
        c_vec[0] = c_vec[1];
        c_vec[2] = c_vec[3];
        c_vec[4] = c_vec[5];
        c_vec[6] = c_vec[7];
    }
    for (int m = 0; m < remain_m; ++m) {
        store_overflow<CType>(c_ptr + m * ldc, c_vec[index_array[m]], remain_n);
    }
}

// pack 4-row blocks of A (int8 widened to int16), two k values per entry
static inline void gemm_s8s8s32_sse_4x8x2_pack_an(dt_int16* out,
                                                  const dt_int8* in, int ldin,
                                                  int m_start, int m_max,
                                                  int k_start, int k_max) {
    constexpr int tile_m = 4;
    constexpr int tile_k_step = 8;
    constexpr int tile_k = 2;
    constexpr int tile_len = tile_m * tile_k_step;
    const int k_size = k_max - k_start;
    const int m_end = (m_max - m_start) / tile_m * tile_m + m_start;
    const int m_remain = m_max - m_end;
    for (int m = m_start; m < m_end; m += tile_m) {
        const dt_int8* in0 = in + m * ldin + k_start;
        const dt_int8* in1 = in0 + ldin;
        const dt_int8* in2 = in1 + ldin;
        const dt_int8* in3 = in2 + ldin;
        int remain_k = k_size;
        for (; remain_k >= tile_k_step; remain_k -= tile_k_step) {
            transpose_4x8_k2_int8_to_int16(in0, in1, in2, in3, out);
            out += tile_len;
            in0 += tile_k_step;
            in1 += tile_k_step;
            in2 += tile_k_step;
            in3 += tile_k_step;
        }
        if (remain_k > 0) {
            transpose_4xk_int8_to_int16_pad(in0, in1, in2, in3, out, remain_k);
            out += tile_m * round_up(remain_k, tile_k);
        }
    }
    if (m_remain > 0) {
        // rows beyond m_max read from a zero buffer so the kernel can
        // always consume full 4-row blocks
        dt_int8 zerobuff[tile_k_step];
        std::memset(zerobuff, 0, sizeof(int8_t) * tile_k_step);
        const dt_int8* in0 = in + m_end * ldin + k_start;
        const dt_int8* in1 = in0 + ldin;
        const dt_int8* in2 = in1 + ldin;
        const dt_int8* in3 = &zerobuff[0];
        int in1_step = tile_k_step;
        int in2_step = tile_k_step;
        if (m_remain < 3) {
            in2 = &zerobuff[0];
            in2_step = 0;
        }
        if (m_remain < 2) {
            in1 = &zerobuff[0];
            in1_step = 0;
        }
        int remain_k = k_size;
        for (; remain_k >= tile_k_step; remain_k -= tile_k_step) {
            transpose_4x8_k2_int8_to_int16(in0, in1, in2, in3, out);
            out += tile_len;
            in0 += tile_k_step;
            in1 += in1_step;
            in2 += in2_step;
        }
        if (remain_k > 0) {
            transpose_4xk_int8_to_int16_pad(in0, in1, in2, in3, out, remain_k);
            out += tile_m * round_up(remain_k, tile_k);
            in0 += tile_k_step;
            in1 += in1_step;
            in2 += in2_step;
        }
    }
}
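
// Layout note (inferred from the kernel reads above; stated as an
// assumption): packed A is organized as one k-pair for each of the four
// rows of a tile, i.e. for k-pair index b and tile row m:
//     out[b * 8 + m * 2 + 0] = A[m][2b]
//     out[b * 8 + m * 2 + 1] = A[m][2b + 1]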

// pack 8-column blocks of B along n, two k values per column entry
static inline void gemm_s8s8s32_sse_4x8x2_pack_bn(dt_int8* out,
                                                  const dt_int8* in, int ldin,
                                                  int n_start, int n_max,
                                                  int k_start, int k_max) {
    constexpr int tile_n = 8;
    constexpr int tile_k = 2;
    constexpr int tile_len = tile_n * tile_k;
    const int k_size = k_max - k_start;
    const int k_end = k_size / tile_k * tile_k + k_start;
    const int k_remain = k_max - k_end;
    const int n_size = n_max - n_start;
    const int n_end = n_size / tile_n * tile_n + n_start;
    const int n_remain = n_max - n_end;
    const int pack_line_len = round_up(k_size, tile_k) * tile_n;
    int k = k_start;
    for (; k < k_end; k += tile_k) {
        int8_t* outptr = out;
        for (int n = n_start; n < n_end; n += tile_n) {
            const dt_int8* inptr_0 = in + k * ldin + n;
            const dt_int8* inptr_1 = inptr_0 + ldin;
            transpose_2x8_no_inc(inptr_0, inptr_1, outptr);
            outptr += pack_line_len;
        }
        if (n_end < n_max) {
            naive_transpose_kn_pad(outptr, in + k * ldin + n_end, ldin, tile_k,
                                   n_remain, tile_k, tile_n);
        }
        out += tile_len;
    }
    if (k_remain > 0) {
        // odd trailing k: pair the last row of B with a zero row
        int8_t* outptr = out;
        dt_int8 zerobuff[tile_n];
        std::memset(zerobuff, 0, sizeof(int8_t) * tile_n);
        for (int n = n_start; n < n_end; n += tile_n) {
            const dt_int8* inptr_0 = in + k * ldin + n;
            const dt_int8* inptr_1 = &zerobuff[0];
            transpose_2x8_no_inc(inptr_0, inptr_1, outptr);
            outptr += pack_line_len;
        }
        if (n_end < n_max) {
            naive_transpose_kn_pad(outptr, in + k * ldin + n_end, ldin,
                                   k_remain, n_remain, tile_k, tile_n);
        }
    }
}
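
// Layout note (inferred from the kernel reads; stated as an assumption):
// packed B stays int8 and stores one k-pair for each of the eight columns
// of a tile, i.e. for k-pair index b and tile column n:
//     out[b * 16 + n * 2 + 0] = B[2b][n]
//     out[b * 16 + n * 2 + 1] = B[2b + 1][n]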

// pack B when it is stored transposed (each input row is a column of B)
static inline void gemm_s8s8s32_sse_4x8x2_pack_bt(dt_int8* out,
                                                  const dt_int8* in, int ldin,
                                                  int n_start, int n_max,
                                                  int k_start, int k_max) {
    constexpr int tile_n = 8;
    constexpr int tile_k = 2;
    constexpr int tile_k_step = 16;
    const int k_size = k_max - k_start;
    const int k_end = k_size / tile_k_step * tile_k_step + k_start;
    const int k_remain = k_max - k_end;
    const int n_size = n_max - n_start;
    const int n_end = n_size / tile_n * tile_n + n_start;
    const int n_remain = n_max - n_end;
    for (int n = n_start; n < n_end; n += tile_n) {
        const dt_int8* in0 = in + n * ldin + k_start;
        const dt_int8* in1 = in0 + ldin;
        const dt_int8* in2 = in1 + ldin;
        const dt_int8* in3 = in2 + ldin;
        const dt_int8* in4 = in3 + ldin;
        const dt_int8* in5 = in4 + ldin;
        const dt_int8* in6 = in5 + ldin;
        const dt_int8* in7 = in6 + ldin;
        for (int k = k_start; k < k_end; k += tile_k_step) {
            transpose_8x16_k2(out, in0, in1, in2, in3, in4, in5, in6, in7);
            in0 += tile_k_step;
            in1 += tile_k_step;
            in2 += tile_k_step;
            in3 += tile_k_step;
            in4 += tile_k_step;
            in5 += tile_k_step;
            in6 += tile_k_step;
            in7 += tile_k_step;
            out += tile_n * tile_k_step;
        }
        naive_transpose_8xk_k2(out, in0, in1, in2, in3, in4, in5, in6, in7,
                               k_remain);
        out += tile_n * round_up(k_remain, tile_k);
    }
    if (n_remain > 0) {
        const dt_int8* in0 = in + n_end * ldin + k_start;
        naive_transpose_nk_k2(out, in0, ldin, n_remain, k_size, tile_n);
    }
}
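
// In the transposed-B path above each input row holds one column of B, so
// the packer reads 16 consecutive k values per column and emits the same
// (column, k-pair) blocks as gemm_s8s8s32_sse_4x8x2_pack_bn; this is an
// inference from the shared consumer kernel, not a documented contract.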

// pack A when it is stored transposed (rows of `in` are k, columns are m)
static inline void gemm_s8s8s32_sse_4x8x2_pack_at(dt_int16* out,
                                                  const dt_int8* in, int ldin,
                                                  int m_start, int m_max,
                                                  int k_start, int k_max) {
    constexpr int tile_m = 8;
    constexpr int tile_m_step = 4;
    constexpr int tile_k = 2;
    const int k_size = k_max - k_start;
    const int k_end = k_size / tile_k * tile_k + k_start;
    const int k_remain = k_max - k_end;
    const int m_size = m_max - m_start;
    const int m_end = m_size / tile_m * tile_m + m_start;
    const int pack_line_len = round_up(k_size, tile_k) * tile_m_step;
    int k = k_start;
    for (; k < k_end; k += tile_k) {
        dt_int16* outptr = out;
        for (int m = m_start; m < m_end; m += tile_m) {
            const dt_int8* inptr_0 = in + k * ldin + m;
            const dt_int8* inptr_1 = inptr_0 + ldin;
            transpose_km_2x8_k2_tile4_int8_to_int16(inptr_0, inptr_1, outptr,
                                                    pack_line_len);
            outptr += (tile_m / tile_m_step) * pack_line_len;
        }
        if (m_end < m_max) {
            for (int m = m_end; m < m_max; m += tile_m_step) {
                const int m_remain =
                        m_max - m >= tile_m_step ? tile_m_step : m_max - m;
                naive_transpose_kn_pad(outptr, in + k * ldin + m, ldin, tile_k,
                                       m_remain, tile_k, tile_m_step);
                outptr += pack_line_len;
            }
        }
        out += tile_m_step * tile_k;
    }
    if (k_remain > 0) {
        // odd trailing k: pair the last row with a zero row
        dt_int16* outptr = out;
        dt_int8 zerobuff[tile_m];
        std::memset(zerobuff, 0, sizeof(int8_t) * tile_m);
        for (int m = m_start; m < m_end; m += tile_m) {
            const dt_int8* inptr_0 = in + k * ldin + m;
            const dt_int8* inptr_1 = &zerobuff[0];
            transpose_km_2x8_k2_tile4_int8_to_int16(inptr_0, inptr_1, outptr,
                                                    pack_line_len);
            outptr += (tile_m / tile_m_step) * pack_line_len;
        }
        if (m_end < m_max) {
            for (int m = m_end; m < m_max; m += tile_m_step) {
                const int m_remain =
                        m_max - m >= tile_m_step ? tile_m_step : m_max - m;
                naive_transpose_kn_pad(outptr, in + k * ldin + m, ldin,
                                       k_remain, m_remain, tile_k,
                                       tile_m_step);
                outptr += pack_line_len;
            }
        }
    }
}
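
// Minimal driver sketch showing how the pieces above fit together for a
// row-major s8s8s32 GEMM. This is illustrative only: the real dispatch
// lives in the strategy code that includes this header, and the buffer
// names (packed_a, packed_b, A, B, C) are hypothetical.
//
//     gemm_s8s8s32_sse_4x8x2_pack_an(packed_a, A, lda, 0, M, 0, K);
//     gemm_s8s8s32_sse_4x8x2_pack_bn(packed_b, B, ldb, 0, N, 0, K);
//     for (int m = 0; m + 4 <= M; m += 4)
//         for (int n = 0; n + 8 <= N; n += 8)
//             kern_gemm_s8s8s32_sse_4x8x2(
//                     packed_a + m * round_up(K, 2),
//                     packed_b + n * round_up(K, 2),
//                     C + m * ldc + n, ldc, K);
//     // partial tiles go through the _remain_m / _remain_n /
//     // _remain_m_n kernels, which use the masked stores above.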

}  // namespace matmul_sse_4x8x2
}  // namespace x86
}  // namespace megdnn

// vim: syntax=cpp.doxygen
