You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

strategy.cpp 7.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201
  1. /**
  2. * \file dnn/src/aarch64/matrix_mul/int8_dot/strategy.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "src/aarch64/matrix_mul/int8_dot/strategy.h"
  12. #if MGB_ENABLE_DOT
  13. #include "src/aarch64/matrix_mul/asm/common.h"
  14. #include "src/arm_common/simd_macro/marm_neon.h"
  15. #include "src/common/utils.h"
  16. #include "src/aarch64/matrix_mul/int8_dot/kernel_8x12x4.h"
  17. #include "src/aarch64/matrix_mul/int8_dot/kernel_mk4_8x12x4.h"
  18. using namespace megdnn;
  19. using namespace aarch64;
  20. using namespace aarch64::matmul;
  21. /* ====================== gemm_s8_8x12 ===========================*/
  22. MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8_8x12);
  23. void gemm_s8_8x12::pack_A(dt_int8* outptr, const dt_int8* inptr, int ldin,
  24. int y0, int ymax, int k0, int kmax,
  25. bool transpose) const {
  26. if (transpose) {
  27. matmul_8x12x4::gemm_s8_8x12_pack_A_t(outptr, inptr, ldin, y0, ymax, k0,
  28. kmax);
  29. } else {
  30. matmul_8x12x4::gemm_s8_8x12_pack_A_n(outptr, inptr, ldin, y0, ymax, k0,
  31. kmax);
  32. }
  33. }
  34. void gemm_s8_8x12::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0,
  35. int xmax, int k0, int kmax, bool transpose) const {
  36. if (transpose) {
  37. matmul_8x12x4::gemm_s8_8x12_pack_B_t(out, in, ldin, x0, xmax, k0, kmax);
  38. } else {
  39. matmul_8x12x4::gemm_s8_8x12_pack_B_n(out, in, ldin, x0, xmax, k0, kmax);
  40. }
  41. }
  42. void gemm_s8_8x12::kern(const dt_int8* packA, const dt_int8* packB, size_t M,
  43. size_t N, size_t K, dt_int32* C, size_t LDC,
  44. bool is_first_k, const dt_int32*, dt_int32*) const {
  45. megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
  46. ((A_dtype.enumv() == DTypeEnum::Int8 &&
  47. C_dtype.enumv() == DTypeEnum::Int32) ||
  48. (A_dtype.enumv() == DTypeEnum::QuantizedS8 &&
  49. C_dtype.enumv() == DTypeEnum::QuantizedS32)),
  50. "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(),
  51. C_dtype.name());
  52. MEGDNN_MARK_USED_VAR(A_dtype);
  53. MEGDNN_MARK_USED_VAR(B_dtype);
  54. MEGDNN_MARK_USED_VAR(C_dtype);
  55. constexpr size_t A_INTERLEAVE = 8;
  56. constexpr size_t B_INTERLEAVE = 12;
  57. //! K is packed to times of 4
  58. K = round_up<size_t>(K, 4);
  59. const int K8 = (K << 3);
  60. const int K12 = K * 12;
  61. const int K4 = K * 4;
  62. size_t m = 0;
  63. for (; m + A_INTERLEAVE - 1 < M; m += A_INTERLEAVE) {
  64. int32_t* output = C + (m * LDC);
  65. size_t n = 0;
  66. const dt_int8* cur_packB = packB;
  67. for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
  68. matmul_8x12x4::kern_8x12(packA, cur_packB, K, output, LDC,
  69. is_first_k);
  70. output += B_INTERLEAVE;
  71. cur_packB += K12;
  72. }
  73. for (; n < N; n += 4) {
  74. matmul_8x12x4::kern_8x4(packA, cur_packB, K, output, LDC,
  75. is_first_k, std::min<size_t>(N - n, 4));
  76. output += 4;
  77. cur_packB += K4;
  78. }
  79. packA += K8;
  80. }
  81. for (; m < M; m += 4) {
  82. int32_t* output = C + (m * LDC);
  83. const dt_int8* cur_packB = packB;
  84. size_t n = 0;
  85. for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
  86. matmul_8x12x4::kern_4x12(packA, cur_packB, K, output, LDC,
  87. is_first_k, std::min<size_t>(M - m, 4));
  88. output += B_INTERLEAVE;
  89. cur_packB += K12;
  90. }
  91. for (; n < N; n += 4) {
  92. matmul_8x12x4::kern_4x4(packA, cur_packB, K, output, LDC,
  93. is_first_k, std::min<size_t>(M - m, 4),
  94. std::min<size_t>(N - n, 4));
  95. output += 4;
  96. cur_packB += K4;
  97. }
  98. packA += K4;
  99. }
  100. }
  101. /* ====================== gemm_mk4_s8_8x12 ===========================*/
  102. MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_mk4_s8_8x12);
  103. void gemm_mk4_s8_8x12::pack_A(dt_int8* outptr, const dt_int8* inptr, int ldin,
  104. int y0, int ymax, int k0, int kmax,
  105. bool transpose) const {
  106. megdnn_assert(!transpose, "matrix mul mk4 with transposed matrix A is not supported");
  107. matmul_mk4_8x12x4::gemm_mk4_s8_8x12_pack_A(outptr, inptr, ldin, y0, ymax, k0,
  108. kmax);
  109. }
  110. void gemm_mk4_s8_8x12::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0,
  111. int xmax, int k0, int kmax,
  112. bool transpose) const {
  113. megdnn_assert(!transpose, "matrix mul mk4 with transposed matrix B is not supported");
  114. matmul_mk4_8x12x4::gemm_mk4_s8_8x12_pack_B(out, in, ldin, x0, xmax, k0, kmax);
  115. }
  116. void gemm_mk4_s8_8x12::kern(const dt_int8* packA, const dt_int8* packB,
  117. size_t M, size_t N, size_t K, dt_int32* C,
  118. size_t LDC, bool is_first_k, const dt_int32*,
  119. dt_int32*) const {
  120. megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
  121. ((A_dtype.enumv() == DTypeEnum::Int8 &&
  122. C_dtype.enumv() == DTypeEnum::Int32) ||
  123. (A_dtype.enumv() == DTypeEnum::QuantizedS8 &&
  124. C_dtype.enumv() == DTypeEnum::QuantizedS32)),
  125. "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(),
  126. C_dtype.name());
  127. MEGDNN_MARK_USED_VAR(A_dtype);
  128. MEGDNN_MARK_USED_VAR(B_dtype);
  129. MEGDNN_MARK_USED_VAR(C_dtype);
  130. constexpr size_t A_INTERLEAVE = 8;
  131. constexpr size_t B_INTERLEAVE = 12;
  132. //! K is packed to times of 4
  133. K = round_up<size_t>(K, 4);
  134. const int K8 = (K << 3);
  135. const int K12 = K * 12;
  136. const int K4 = K * 4;
  137. size_t m = 0;
  138. for (; m + A_INTERLEAVE - 1 < M; m += A_INTERLEAVE) {
  139. int32_t* output = C + ((m >> 2) * LDC);
  140. size_t n = 0;
  141. const dt_int8* cur_packB = packB;
  142. for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
  143. matmul_mk4_8x12x4::kern_8x12(packA, cur_packB, K, output, LDC,
  144. is_first_k);
  145. output += (B_INTERLEAVE << 2);
  146. cur_packB += K12;
  147. }
  148. for (; n < N; n += 4) {
  149. matmul_mk4_8x12x4::kern_8x4(packA, cur_packB, K, output, LDC,
  150. is_first_k, std::min<size_t>(N - n, 4));
  151. output += 16;
  152. cur_packB += K4;
  153. }
  154. packA += K8;
  155. }
  156. for (; m < M; m += 4) {
  157. int32_t* output = C + ((m >> 2) * LDC);
  158. const dt_int8* cur_packB = packB;
  159. size_t n = 0;
  160. for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
  161. matmul_mk4_8x12x4::kern_4x12(packA, cur_packB, K, output, LDC,
  162. is_first_k);
  163. output += (B_INTERLEAVE << 2);
  164. cur_packB += K12;
  165. }
  166. for (; n < N; n += 4) {
  167. matmul_mk4_8x12x4::kern_4x4(packA, cur_packB, K, output, LDC,
  168. is_first_k, std::min<size_t>(N - n, 4));
  169. output += 16;
  170. cur_packB += K4;
  171. }
  172. packA += K4;
  173. }
  174. }
  175. #endif
  176. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台