You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

strategy.cpp 6.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166
  1. /**
  2. * \file dnn/src/aarch64/matrix_mul/fp32/strategy.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "src/aarch64/matrix_mul/fp32/strategy.h"
  12. #include "src/aarch64/matrix_mul/fp32/kernel_general_4x16.h"
  13. #include "src/aarch64/matrix_mul/fp32/kernel_general_8x12.h"
  14. #include "src/common/utils.h"
  15. using namespace megdnn;
  16. using namespace aarch64;
  17. using namespace aarch64::matmul;
  18. MEGDNN_REG_GEMM_STRATEGY_IMPL(sgemm_4x16);
  19. void sgemm_4x16::pack_A(float* out, const float* in, int ldin, int y0,
  20. int ymax, int k0, int kmax, bool transpose_A) const {
  21. if (transpose_A) {
  22. matmul_general_4x16::sgemm_4x16_pack_A_t(out, in, ldin, y0, ymax, k0, kmax);
  23. } else {
  24. matmul_general_4x16::sgemm_4x16_pack_A_n(out, in, ldin, y0, ymax, k0, kmax);
  25. }
  26. }
  27. void sgemm_4x16::pack_B(float* out, const float* in, int ldin, int x0, int xmax,
  28. int k0, int kmax, bool transpose_B) const {
  29. if (transpose_B) {
  30. matmul_general_4x16::sgemm_4x16_pack_B_t(out, in, ldin, x0, xmax, k0, kmax);
  31. } else {
  32. matmul_general_4x16::sgemm_4x16_pack_B_n(out, in, ldin, x0, xmax, k0, kmax);
  33. }
  34. }
  35. void sgemm_4x16::kern(const float* packA, const float* packB,
  36. size_t M, size_t N, size_t K, float* C, size_t LDC,
  37. bool is_first_k, const float*, float*) const {
  38. megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
  39. A_dtype.enumv() == C_dtype.enumv() &&
  40. A_dtype.enumv() == DTypeEnum::Float32);
  41. MEGDNN_MARK_USED_VAR(A_dtype);
  42. MEGDNN_MARK_USED_VAR(B_dtype);
  43. MEGDNN_MARK_USED_VAR(C_dtype);
  44. constexpr size_t A_INTERLEAVE = 4;
  45. constexpr size_t B_INTERLEAVE = 16;
  46. const int K16 = K * 16;
  47. const int K4 = K * 4;
  48. size_t m = 0;
  49. for (; m < M; m += A_INTERLEAVE) {
  50. float* output = C + (m * LDC);
  51. size_t n = 0;
  52. const float* cur_packB = packB;
  53. for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
  54. matmul_general_4x16::kern_4x16(packA, cur_packB, K, output, LDC, is_first_k,
  55. std::min<size_t>(M - m, 4));
  56. output += B_INTERLEAVE;
  57. cur_packB += K16;
  58. }
  59. for (; n < N; n += 4) {
  60. matmul_general_4x16::kern_4x4(packA, cur_packB, K, output, LDC, is_first_k,
  61. std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4));
  62. output += 4;
  63. cur_packB += K4;
  64. }
  65. packA += K4;
  66. }
  67. }
  68. MEGDNN_REG_GEMM_STRATEGY_IMPL(sgemm_8x12);
  69. void sgemm_8x12::pack_A(float* out, const float* in, int ldin, int y0,
  70. int ymax, int k0, int kmax, bool transpose_A) const {
  71. if (transpose_A) {
  72. matmul_general_8x12::sgemm_8x12_pack_A_t(out, in, ldin, y0, ymax, k0,
  73. kmax);
  74. } else {
  75. matmul_general_8x12::sgemm_8x12_pack_A_n(out, in, ldin, y0, ymax, k0,
  76. kmax);
  77. }
  78. }
  79. void sgemm_8x12::pack_B(float* out, const float* in, int ldin, int x0, int xmax,
  80. int k0, int kmax, bool transpose_B) const {
  81. if (transpose_B) {
  82. matmul_general_8x12::sgemm_8x12_pack_B_t(out, in, ldin, x0, xmax, k0,
  83. kmax);
  84. } else {
  85. matmul_general_8x12::sgemm_8x12_pack_B_n(out, in, ldin, x0, xmax, k0,
  86. kmax);
  87. }
  88. }
  89. void sgemm_8x12::kern(const float* packA, const float* packB,
  90. size_t M, size_t N, size_t K, float* C, size_t LDC,
  91. bool is_first_k, const float*, float*) const {
  92. megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
  93. A_dtype.enumv() == C_dtype.enumv() &&
  94. A_dtype.enumv() == DTypeEnum::Float32);
  95. MEGDNN_MARK_USED_VAR(A_dtype);
  96. MEGDNN_MARK_USED_VAR(B_dtype);
  97. MEGDNN_MARK_USED_VAR(C_dtype);
  98. constexpr size_t A_INTERLEAVE = 8;
  99. constexpr size_t A_INTERLEAVE4 = 4;
  100. constexpr size_t B_INTERLEAVE = 12;
  101. const int K12 = K * 12;
  102. const int K8 = K * 8;
  103. const int K4 = K * 4;
  104. size_t m = 0;
  105. for (; m + A_INTERLEAVE <= M; m += A_INTERLEAVE) {
  106. float* output = C + (m * LDC);
  107. size_t n = 0;
  108. const float* cur_packB = packB;
  109. for (; n + B_INTERLEAVE <= N; n += B_INTERLEAVE) {
  110. matmul_general_8x12::kern_8x12(packA, cur_packB, K, output, LDC,
  111. is_first_k);
  112. output += B_INTERLEAVE;
  113. cur_packB += K12;
  114. }
  115. for (; n < N; n += 4) {
  116. matmul_general_8x12::kern_8x4(packA, cur_packB, K, output, LDC,
  117. is_first_k,
  118. std::min<size_t>(N - n, 4));
  119. output += 4;
  120. cur_packB += K4;
  121. }
  122. packA += K8;
  123. }
  124. for (; m < M; m += A_INTERLEAVE4) {
  125. float* output = C + (m * LDC);
  126. size_t n = 0;
  127. const float* cur_packB = packB;
  128. for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
  129. matmul_general_8x12::kern_4x12(packA, cur_packB, K, output, LDC,
  130. is_first_k,
  131. std::min<size_t>(M - m, 4));
  132. output += B_INTERLEAVE;
  133. cur_packB += K12;
  134. }
  135. for (; n < N; n += 4) {
  136. matmul_general_8x12::kern_4x4(
  137. packA, cur_packB, K, output, LDC, is_first_k,
  138. std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4));
  139. output += 4;
  140. cur_packB += K4;
  141. }
  142. packA += K4;
  143. }
  144. }
  145. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台