You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

cutlass_matrix_mul_wrapper.cu 8.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157
  1. /**
  2. * \file dnn/src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cu
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. // ignore warning of cutlass
  13. #include "cuda.h"
  14. #if __CUDACC_VER_MAJOR__ > 9 || \
  15. (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
  16. #pragma GCC diagnostic push
  17. #pragma GCC diagnostic ignored "-Wunused-parameter"
  18. #pragma GCC diagnostic ignored "-Wstrict-aliasing"
  19. #include "cutlass/gemm/device/gemm.h"
  20. #include "cutlass/gemm/device/gemm_splitk_parallel.h"
  21. #include "cutlass/gemm/kernel/default_gemv.h"
  22. #include "src/common/opr_param_defs_enumv.cuh"
  23. #include "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuh"
  24. #pragma GCC diagnostic pop
  25. using namespace megdnn;
  26. using namespace cuda;
  27. using namespace cutlass_wrapper;
  28. /* ================= cutlass kernel wrapper for f32 matrix mul ================
  29. */
/* Dispatch helper: expands `cb` once for each supported
 * (threadblock_m, threadblock_n, threadblock_k, warp_m, warp_n, warp_k)
 * tile configuration.  Each `cb` expansion is an if-block that `return`s
 * when the runtime shapes match (see the `cb` definitions below), so
 * falling through every case reaches the trailing megdnn_assert for an
 * unsupported shape combination.
 * NOTE: relies on `threadblock_shape` and `warp_shape` being in scope at
 * the expansion site. */
#define DISPATCH(cb)                                                    \
    cb(64, 256, 8, 32, 64, 8);                                          \
    cb(256, 64, 8, 64, 32, 8);                                          \
    cb(32, 256, 8, 16, 64, 8);                                          \
    cb(256, 32, 8, 64, 16, 8);                                          \
    cb(128, 128, 8, 32, 64, 8);                                         \
    cb(128, 64, 8, 64, 32, 8);                                          \
    cb(64, 128, 8, 32, 64, 8);                                          \
    cb(128, 32, 8, 64, 32, 8);                                          \
    cb(32, 128, 8, 32, 64, 8);                                          \
    cb(64, 64, 8, 32, 64, 8);                                           \
    cb(32, 64, 8, 32, 64, 8);                                           \
    cb(64, 32, 8, 64, 32, 8);                                           \
    cb(32, 32, 8, 32, 32, 8);                                           \
    cb(8, 32, 8, 8, 32, 8);                                             \
    cb(16, 32, 8, 16, 32, 8);                                           \
    cb(16, 64, 8, 16, 64, 8);                                           \
    cb(16, 128, 8, 16, 64, 8);                                          \
    megdnn_assert(false,                                                \
                  "unsupported threadblock shape (%dx%dx%d) and warp shape " \
                  "(%dx%dx%d)",                                         \
                  threadblock_shape.m(), threadblock_shape.n(),         \
                  threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \
                  warp_shape.k());
/* Runtime dispatcher for the f32 SIMT (non-tensor-core) cutlass GEMM.
 *
 * Selects a statically instantiated cutlass kernel whose threadblock/warp
 * tile shapes match the requested `threadblock_shape` / `warp_shape`, for
 * each of the four row/column-major layout combinations implied by
 * `transpose_A` / `transpose_B`, and launches it on `stream` via
 * cutlass_matrix_mul_wrapper<Gemm>.
 *
 * \param d_A, d_B, d_C   device pointers to the operand / output matrices
 * \param lda, ldb, ldc   leading dimensions of A, B, C
 * \param workspace       device scratch buffer (used by the split-k path)
 * \param problem_size    GEMM m/n/k extents
 * \param alpha, beta     epilogue scalars (D = alpha * accum + beta * C)
 * \param split_k_slices  1 -> plain device::Gemm; >1 -> GemmSplitKParallel
 *
 * Asserts (via DISPATCH's trailing megdnn_assert) when the requested tile
 * shapes match none of the compiled configurations.
 */
void megdnn::cuda::cutlass_wrapper::cutlass_matrix_mul_float32_simt(
        const float* d_A, bool transpose_A, size_t lda, const float* d_B,
        bool transpose_B, size_t ldb, float* d_C, size_t ldc, int* workspace,
        GemmCoord const& problem_size, float alpha, float beta,
        const GemmCoord& threadblock_shape, const GemmCoord& warp_shape,
        cudaStream_t stream, int split_k_slices) {
    // Epilogue: per-element linear combination alpha * accum + beta * C.
    static constexpr int kEpilogueElementsPerAccess = 1;
    using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
            float, kEpilogueElementsPerAccess, float, float>;
    typename EpilogueOp::Params epilogue{alpha, beta};
    if (split_k_slices == 1) {
/* Single-slice path: ordinary cutlass::gemm::device::Gemm.  `cb` expands,
 * under DISPATCH, to a shape-match test that instantiates and launches the
 * corresponding kernel; `LayoutA` / `LayoutB` are taken from the enclosing
 * branch at the expansion site. */
#define cb(threadblock_m_, threadblock_n_, threadblock_k_, warp_m_, warp_n_, \
           warp_k_)                                                          \
    if (threadblock_shape.m() == threadblock_m_ &&                           \
        threadblock_shape.n() == threadblock_n_ &&                           \
        threadblock_shape.k() == threadblock_k_ &&                           \
        warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ &&            \
        warp_shape.k() == warp_k_) {                                         \
        using ThreadBlockShape =                                             \
                cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_,     \
                                         threadblock_k_>;                    \
        using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \
        using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;          \
        using Gemm = cutlass::gemm::device::Gemm<                            \
                float, LayoutA, float, LayoutB, float,                       \
                cutlass::layout::RowMajor, float, cutlass::arch::OpClassSimt, \
                cutlass::arch::Sm50, ThreadBlockShape, WarpShape,            \
                InstructionShape, EpilogueOp,                                \
                cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, \
                2>;                                                          \
        return cutlass_matrix_mul_wrapper<Gemm>(d_A, lda, d_B, ldb, d_C, ldc, \
                                                workspace, problem_size,     \
                                                epilogue, stream);           \
    }
        // Pick operand layouts from the transpose flags, then dispatch.
        if (!transpose_A && !transpose_B) {
            using LayoutA = cutlass::layout::RowMajor;
            using LayoutB = cutlass::layout::RowMajor;
            DISPATCH(cb)
        } else if (!transpose_A && transpose_B) {
            using LayoutA = cutlass::layout::RowMajor;
            using LayoutB = cutlass::layout::ColumnMajor;
            DISPATCH(cb)
        } else if (transpose_A && !transpose_B) {
            using LayoutA = cutlass::layout::ColumnMajor;
            using LayoutB = cutlass::layout::RowMajor;
            DISPATCH(cb)
        } else {
            megdnn_assert(transpose_A && transpose_B);
            using LayoutA = cutlass::layout::ColumnMajor;
            using LayoutB = cutlass::layout::ColumnMajor;
            DISPATCH(cb)
        }
#undef cb
    } else {
/* Split-k path: cutlass::gemm::device::GemmSplitKParallel partitions the k
 * dimension across `split_k_slices` slices; partial results are reduced via
 * `workspace`.  Same dispatch structure as the single-slice `cb` above. */
#define cb(threadblock_m_, threadblock_n_, threadblock_k_, warp_m_, warp_n_, \
           warp_k_)                                                          \
    if (threadblock_shape.m() == threadblock_m_ &&                           \
        threadblock_shape.n() == threadblock_n_ &&                           \
        threadblock_shape.k() == threadblock_k_ &&                           \
        warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ &&            \
        warp_shape.k() == warp_k_) {                                         \
        using ThreadBlockShape =                                             \
                cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_,     \
                                         threadblock_k_>;                    \
        using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \
        using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;          \
        using Gemm = cutlass::gemm::device::GemmSplitKParallel<              \
                float, LayoutA, float, LayoutB, float,                       \
                cutlass::layout::RowMajor, float, cutlass::arch::OpClassSimt, \
                cutlass::arch::Sm50, ThreadBlockShape, WarpShape,            \
                InstructionShape, EpilogueOp>;                               \
        return cutlass_matrix_mul_wrapper<Gemm>(                             \
                d_A, lda, d_B, ldb, d_C, ldc, workspace, problem_size,       \
                epilogue, stream, split_k_slices);                           \
    }
        // Pick operand layouts from the transpose flags, then dispatch.
        if (!transpose_A && !transpose_B) {
            using LayoutA = cutlass::layout::RowMajor;
            using LayoutB = cutlass::layout::RowMajor;
            DISPATCH(cb)
        } else if (!transpose_A && transpose_B) {
            using LayoutA = cutlass::layout::RowMajor;
            using LayoutB = cutlass::layout::ColumnMajor;
            DISPATCH(cb)
        } else if (transpose_A && !transpose_B) {
            using LayoutA = cutlass::layout::ColumnMajor;
            using LayoutB = cutlass::layout::RowMajor;
            DISPATCH(cb)
        } else {
            megdnn_assert(transpose_A && transpose_B);
            using LayoutA = cutlass::layout::ColumnMajor;
            using LayoutB = cutlass::layout::ColumnMajor;
            DISPATCH(cb)
        }
#undef cb
    }
}
  150. #undef DISPATCH
  151. #endif
  152. // vim: syntax=cuda.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台