You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

common.h 6.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221
  1. #pragma once
  2. #include "src/fallback/general_intrinsic/gi_float.h"
  3. namespace megdnn {
  4. namespace matmul {
  5. namespace fallback {
  6. /* ======================== transform ======================== */
  7. /**
  8. * interleave_INTERLEAVE_UNROLLK_BATCH_type
  9. *
  10. * BATCH means process BATCH * UNROLL_K cols once, BATCH * sizeof(TYPE) *
  11. * UNROLL_K = 16bytes(128bits, a vector size).
  12. *
  13. * the elements traverse order:
  14. * rep(j, 0, INTERLEAVE) rep(i, 0, UNROLL_K) *outptr++ = inptr[j, i]
  15. */
  16. template <typename T>
  17. static GI_FORCEINLINE void interleave_4x4_1_s(
  18. const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3,
  19. T*& outptr) {
  20. static_assert(sizeof(T) == 4, "interleave_4x4_1_s only support sizeof(T) == 4");
  21. GI_FLOAT32_t d0d1 = GiLoadFloat32(inptr0);
  22. GI_FLOAT32_t d2d3 = GiLoadFloat32(inptr1);
  23. GI_FLOAT32_t d4d5 = GiLoadFloat32(inptr2);
  24. GI_FLOAT32_t d6d7 = GiLoadFloat32(inptr3);
  25. inptr0 += 4;
  26. inptr1 += 4;
  27. inptr2 += 4;
  28. inptr3 += 4;
  29. GiStoreFloat32(outptr, d0d1);
  30. outptr += 4;
  31. GiStoreFloat32(outptr, d2d3);
  32. outptr += 4;
  33. GiStoreFloat32(outptr, d4d5);
  34. outptr += 4;
  35. GiStoreFloat32(outptr, d6d7);
  36. outptr += 4;
  37. }
  38. template <typename T>
  39. static GI_FORCEINLINE void interleave_4x12_1_s(
  40. const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3,
  41. T*& outptr) {
  42. static_assert(sizeof(T) == 4, "interleave_4x12_1_s only support sizeof(T) == 4");
  43. GI_FLOAT32_t d0d1 = GiLoadFloat32(inptr0);
  44. inptr0 += 4;
  45. GI_FLOAT32_t d2d3 = GiLoadFloat32(inptr0);
  46. inptr0 += 4;
  47. GI_FLOAT32_t d4d5 = GiLoadFloat32(inptr0);
  48. inptr0 += 4;
  49. GI_FLOAT32_t d6d7 = GiLoadFloat32(inptr1);
  50. inptr1 += 4;
  51. GI_FLOAT32_t d8d9 = GiLoadFloat32(inptr1);
  52. inptr1 += 4;
  53. GI_FLOAT32_t d10d11 = GiLoadFloat32(inptr1);
  54. inptr1 += 4;
  55. GI_FLOAT32_t d12d13 = GiLoadFloat32(inptr2);
  56. inptr2 += 4;
  57. GI_FLOAT32_t d14d15 = GiLoadFloat32(inptr2);
  58. inptr2 += 4;
  59. GI_FLOAT32_t d16d17 = GiLoadFloat32(inptr2);
  60. inptr2 += 4;
  61. GI_FLOAT32_t d18d19 = GiLoadFloat32(inptr3);
  62. inptr3 += 4;
  63. GI_FLOAT32_t d20d21 = GiLoadFloat32(inptr3);
  64. inptr3 += 4;
  65. GI_FLOAT32_t d22d23 = GiLoadFloat32(inptr3);
  66. inptr3 += 4;
  67. GiStoreFloat32(outptr, d0d1);
  68. outptr += 4;
  69. GiStoreFloat32(outptr, d2d3);
  70. outptr += 4;
  71. GiStoreFloat32(outptr, d4d5);
  72. outptr += 4;
  73. GiStoreFloat32(outptr, d6d7);
  74. outptr += 4;
  75. GiStoreFloat32(outptr, d8d9);
  76. outptr += 4;
  77. GiStoreFloat32(outptr, d10d11);
  78. outptr += 4;
  79. GiStoreFloat32(outptr, d12d13);
  80. outptr += 4;
  81. GiStoreFloat32(outptr, d14d15);
  82. outptr += 4;
  83. GiStoreFloat32(outptr, d16d17);
  84. outptr += 4;
  85. GiStoreFloat32(outptr, d18d19);
  86. outptr += 4;
  87. GiStoreFloat32(outptr, d20d21);
  88. outptr += 4;
  89. GiStoreFloat32(outptr, d22d23);
  90. outptr += 4;
  91. }
  92. template <typename T>
  93. static GI_FORCEINLINE void interleave_1x12_1_s(const T*& inptr0, T*& outptr) {
  94. static_assert(sizeof(T) == 4, "interleave_1x12_1_s only support sizeof(T) == 4");
  95. GI_FLOAT32_t d0d1 = GiLoadFloat32(inptr0);
  96. inptr0 += 4;
  97. GI_FLOAT32_t d2d3 = GiLoadFloat32(inptr0);
  98. inptr0 += 4;
  99. GI_FLOAT32_t d4d5 = GiLoadFloat32(inptr0);
  100. inptr0 += 4;
  101. GiStoreFloat32(outptr, d0d1);
  102. outptr += 4;
  103. GiStoreFloat32(outptr, d2d3);
  104. outptr += 4;
  105. GiStoreFloat32(outptr, d4d5);
  106. outptr += 4;
  107. }
  108. template <typename T>
  109. static GI_FORCEINLINE void interleave_1x4_1_s(const T*& inptr0, T*& outptr) {
  110. static_assert(sizeof(T) == 4, "interleave_1x4_1_s only support sizeof(T) == 4");
  111. GI_FLOAT32_t d0d1 = GiLoadFloat32(inptr0);
  112. inptr0 += 4;
  113. GiStoreFloat32(outptr, d0d1);
  114. outptr += 4;
  115. }
  /**
   * Scalar tail packer: copy ksize elements, then pad up to unroll_k.
   *
   * Copies ksize elements from inptr to outptr, then writes val until a total
   * of unroll_k elements has been emitted. If ksize >= unroll_k no padding is
   * written. Both pointers are advanced past everything read or written.
   *
   * \param inptr     source cursor, advanced by ksize elements (when ksize > 0)
   * \param outptr    destination cursor, advanced by max(ksize, unroll_k)
   *                  elements (when that is positive)
   * \param unroll_k  width of the packed group
   * \param ksize     number of real elements available to copy
   * \param val       padding value written after the real elements
   */
  template <typename T>
  static GI_FORCEINLINE void interleave_helper(
          const T*& inptr, T*& outptr, int unroll_k, int ksize, T val = 0) {
      int emitted = 0;

      // Real data first.
      while (emitted < ksize) {
          *outptr = *inptr;
          ++outptr;
          ++inptr;
          ++emitted;
      }

      // Fill the rest of the group with the padding value.
      for (int pad = emitted; pad < unroll_k; ++pad) {
          *outptr = val;
          ++outptr;
      }
  }
  127. template <typename T>
  128. static GI_FORCEINLINE void interleave_1(
  129. const T*& inptr0, T*& outptr, int unroll_k, int ksize, T val = 0) {
  130. for (int k = 0; k < ksize; k += unroll_k) {
  131. int size = std::min(unroll_k, ksize - k);
  132. interleave_helper(inptr0, outptr, unroll_k, size, val);
  133. }
  134. }
  135. template <typename T>
  136. static GI_FORCEINLINE void interleave_4(
  137. const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3,
  138. T*& outptr, int unroll_k, int ksize, T val = 0) {
  139. for (int k = 0; k < ksize; k += unroll_k) {
  140. int size = std::min(unroll_k, ksize - k);
  141. interleave_helper(inptr0, outptr, unroll_k, size, val);
  142. interleave_helper(inptr1, outptr, unroll_k, size, val);
  143. interleave_helper(inptr2, outptr, unroll_k, size, val);
  144. interleave_helper(inptr3, outptr, unroll_k, size, val);
  145. }
  146. }
  147. /* ======================== transpose pack B ======================== */
  148. /**
  149. * transpose_INTERLEAVE_UNROLLK_BATCH_type
  150. *
  151. * BATCH means process BATCH * INTERLEAVE cols once, BATCH * sizeof(TYPE) *
  152. * INTERLEAVE = 16bytes(128bits, a vector size).
  153. *
  154. * the elements traverse order:
  155. * rep(j, 0, INTERLEAVE) rep(i, 0, UNROLL_K) *outptr++ = inptr[i, j]
  156. */
  157. template <typename T>
  158. static GI_FORCEINLINE void transpose_4x4_1_s(
  159. const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3,
  160. T*& outptr, int stride = 16) {
  161. static_assert(sizeof(T) == 4, "transpose_4x4_1_s only support sizeof(T) == 4");
  162. stride = stride / sizeof(float);
  163. stride -= 2;
  164. GI_FLOAT32_t d0d1 = GiLoadFloat32(inptr0);
  165. GI_FLOAT32_t d2d3 = GiLoadFloat32(inptr1);
  166. GI_FLOAT32_t d4d5 = GiLoadFloat32(inptr2);
  167. GI_FLOAT32_t d6d7 = GiLoadFloat32(inptr3);
  168. inptr0 += 4;
  169. inptr1 += 4;
  170. inptr2 += 4;
  171. inptr3 += 4;
  172. GI_FLOAT32_V2_t q0q1 = GiZipqFloat32(d0d1, d2d3);
  173. GI_FLOAT32_V2_t q2q3 = GiZipqFloat32(d4d5, d6d7);
  174. GiSt1Float32(outptr, GiGetLowFloat32(GiGetSubVectorFloat32V2(q0q1, 0)));
  175. outptr += 2;
  176. GiSt1Float32(outptr, GiGetLowFloat32(GiGetSubVectorFloat32V2(q2q3, 0)));
  177. outptr += stride;
  178. GiSt1Float32(outptr, GiGetHighFloat32(GiGetSubVectorFloat32V2(q0q1, 0)));
  179. outptr += 2;
  180. GiSt1Float32(outptr, GiGetHighFloat32(GiGetSubVectorFloat32V2(q2q3, 0)));
  181. outptr += stride;
  182. GiSt1Float32(outptr, GiGetLowFloat32(GiGetSubVectorFloat32V2(q0q1, 1)));
  183. outptr += 2;
  184. GiSt1Float32(outptr, GiGetLowFloat32(GiGetSubVectorFloat32V2(q2q3, 1)));
  185. outptr += stride;
  186. GiSt1Float32(outptr, GiGetHighFloat32(GiGetSubVectorFloat32V2(q0q1, 1)));
  187. outptr += 2;
  188. GiSt1Float32(outptr, GiGetHighFloat32(GiGetSubVectorFloat32V2(q2q3, 1)));
  189. outptr += stride;
  190. }
  191. } // namespace fallback
  192. } // namespace matmul
  193. } // namespace megdnn
  194. // vim: syntax=cpp.doxygen