/**
 * \file dnn/src/cuda/matrix_mul/algos.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "./algos.h"
#include <cuda.h>
#include "src/common/algo_base.h"
#include "src/cuda/conv_bias/algo.h"
#include "src/cuda/conv_bias/opr_impl.h"
#include "src/cuda/utils.h"
#if CUDA_VERSION >= 10010
#include <cublasLt.h>
#endif

using namespace megdnn;
using namespace cuda;
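
// AlgoPack registers every matrix-mul algorithm available in this build:
// cuBLAS, the WMMA uint4 kernel, cuBLASLt, the bfloat16 wrapper, the CUTLASS
// SIMT / tensor-op kernels, the naive fallback and the 1x1-convolution-via-cuDNN
// wrappers, each behind the compile-time guard it requires. It then builds the
// desc -> algo lookup map used to recover an algorithm from its descriptor.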
MatrixMulForwardImpl::AlgoPack::AlgoPack() {
    all_algos.push_back(&cublas);
#if CUDA_VERSION >= 10000
    all_algos.push_back(&wmma_uint4x4x32);
#endif
#if CUDA_VERSION >= 10010
    all_algos.push_back(&cublas_lt);
#endif
#if !MEGDNN_DISABLE_FLOAT16
    all_algos.push_back(&bfloat16);
#endif
#if CUDA_VERSION >= 9020
    fill_cutlass_algos();
    for (auto&& algo : simt_float32) {
        all_algos.push_back(&algo);
    }
    for (auto&& algo : simt_float32_split_k) {
        all_algos.push_back(&algo);
    }
    for (auto&& algo : simt_float32_gemv_batched_strided) {
        all_algos.push_back(&algo);
    }
#if CUDA_VERSION >= 10010
    for (auto&& algo : tensorop_float16) {
        all_algos.push_back(&algo);
    }
    for (auto&& algo : tensorop_float16_split_k) {
        all_algos.push_back(&algo);
    }
#endif
#endif
    all_algos.push_back(&naive);
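
    // Wrap each cuDNN forward-convolution algorithm as a 1x1-convolution-based
    // matrix-mul algorithm and register it as well.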
    std::vector<cudnnConvolutionFwdAlgo_t> cudnn_conv_enum;
    for (auto&& algo : CudnnAlgoPack::conv_fwd_algos()) {
        cudnn_conv_enum.push_back(algo.first);
    }
    for (auto&& algo : cudnn_conv_enum) {
        conv1x1.push_back(AlgoConv1X1CUDNN(algo));
    }
    for (size_t i = 0; i < conv1x1.size(); ++i) {
        all_algos.push_back(&conv1x1[i]);
    }

    for (auto&& algo : all_algos) {
        m_all_algos_map.emplace(algo->info().desc, algo);
    }
}

#if CUDA_VERSION >= 9020
void MatrixMulForwardImpl::AlgoPack::fill_cutlass_algos() {
    using AlgoParam = AlgoCutlassMatrixMulBase::AlgoParam;
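    // Each AlgoParam below appears to follow the usual CUTLASS tiling
    // convention: {threadblock_m, threadblock_n, threadblock_k, warp_m,
    // warp_n, warp_k}; the tensor-op shapes further down append an
    // {instruction_m, instruction_n, instruction_k} triple. This field
    // ordering is inferred from context, not documented in this file.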
    simt_float32.emplace_back(AlgoParam{64, 256, 8, 32, 64, 8});
    simt_float32.emplace_back(AlgoParam{256, 64, 8, 64, 32, 8});
    simt_float32.emplace_back(AlgoParam{32, 256, 8, 16, 64, 8});
    simt_float32.emplace_back(AlgoParam{256, 32, 8, 64, 16, 8});
    simt_float32.emplace_back(AlgoParam{128, 128, 8, 32, 64, 8});
    simt_float32.emplace_back(AlgoParam{128, 64, 8, 64, 32, 8});
    simt_float32.emplace_back(AlgoParam{64, 128, 8, 32, 64, 8});
    simt_float32.emplace_back(AlgoParam{128, 32, 8, 64, 32, 8});
    simt_float32.emplace_back(AlgoParam{32, 128, 8, 32, 64, 8});
    simt_float32.emplace_back(AlgoParam{64, 64, 8, 32, 64, 8});
    simt_float32.emplace_back(AlgoParam{32, 64, 8, 32, 64, 8});
    simt_float32.emplace_back(AlgoParam{64, 32, 8, 64, 32, 8});
    simt_float32.emplace_back(AlgoParam{32, 32, 8, 32, 32, 8});
    simt_float32.emplace_back(AlgoParam{8, 32, 8, 8, 32, 8});
    simt_float32.emplace_back(AlgoParam{16, 32, 8, 16, 32, 8});
    simt_float32.emplace_back(AlgoParam{16, 64, 8, 16, 64, 8});
    simt_float32.emplace_back(AlgoParam{16, 128, 8, 16, 64, 8});
    simt_float32_split_k.emplace_back(AlgoParam{64, 256, 8, 32, 64, 8});
    simt_float32_split_k.emplace_back(AlgoParam{256, 64, 8, 64, 32, 8});
    simt_float32_split_k.emplace_back(AlgoParam{32, 256, 8, 16, 64, 8});
    simt_float32_split_k.emplace_back(AlgoParam{256, 32, 8, 64, 16, 8});
    simt_float32_split_k.emplace_back(AlgoParam{128, 128, 8, 32, 64, 8});
    simt_float32_split_k.emplace_back(AlgoParam{128, 64, 8, 64, 32, 8});
    simt_float32_split_k.emplace_back(AlgoParam{64, 128, 8, 32, 64, 8});
    simt_float32_split_k.emplace_back(AlgoParam{128, 32, 8, 64, 32, 8});
    simt_float32_split_k.emplace_back(AlgoParam{32, 128, 8, 32, 64, 8});
    simt_float32_split_k.emplace_back(AlgoParam{64, 64, 8, 32, 64, 8});
    simt_float32_split_k.emplace_back(AlgoParam{32, 64, 8, 32, 64, 8});
    simt_float32_split_k.emplace_back(AlgoParam{64, 32, 8, 64, 32, 8});
    simt_float32_split_k.emplace_back(AlgoParam{32, 32, 8, 32, 32, 8});
    simt_float32_split_k.emplace_back(AlgoParam{8, 32, 8, 8, 32, 8});
    simt_float32_split_k.emplace_back(AlgoParam{16, 32, 8, 16, 32, 8});
    simt_float32_split_k.emplace_back(AlgoParam{16, 64, 8, 16, 64, 8});
    simt_float32_split_k.emplace_back(AlgoParam{16, 128, 8, 16, 64, 8});
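    // The single constructor argument of the batched-strided GEMV variants is
    // a kernel tiling parameter (128/64/32); see the algorithm's declaration
    // in ./algos.h for its exact meaning.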
    simt_float32_gemv_batched_strided.emplace_back(128);
    simt_float32_gemv_batched_strided.emplace_back(64);
    simt_float32_gemv_batched_strided.emplace_back(32);
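    // Tensor-op tile shapes: the trailing 8x8x4 / 16x8x8 triples match the
    // native mma instruction shapes of SM70 (Volta) and SM75 (Turing),
    // respectively.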
#define FOREACH_CUTLASS_MATMUL_MMA_SM70_SHAPES(cb) \
    cb(256, 128, 32, 64, 64, 32, 8, 8, 4);         \
    cb(128, 256, 32, 64, 64, 32, 8, 8, 4);         \
    cb(128, 128, 32, 64, 64, 32, 8, 8, 4);
#define FOREACH_CUTLASS_MATMUL_MMA_SM75_SHAPES(cb) \
    cb(256, 128, 32, 64, 64, 32, 16, 8, 8);        \
    cb(128, 256, 32, 64, 64, 32, 16, 8, 8);        \
    cb(128, 128, 32, 64, 64, 32, 16, 8, 8);
#define cb(...)                                                \
    tensorop_float16.emplace_back(AlgoParam{__VA_ARGS__});     \
    tensorop_float16_split_k.emplace_back(AlgoParam{__VA_ARGS__});
#if CUDA_VERSION >= 10010
    FOREACH_CUTLASS_MATMUL_MMA_SM70_SHAPES(cb)
#endif
#if CUDA_VERSION >= 10020
    FOREACH_CUTLASS_MATMUL_MMA_SM75_SHAPES(cb)
#endif
#undef cb
#undef FOREACH_CUTLASS_MATMUL_MMA_SM70_SHAPES
#undef FOREACH_CUTLASS_MATMUL_MMA_SM75_SHAPES
}
#endif

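// Single static AlgoPack instance shared by all MatrixMulForwardImpl operators.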
MatrixMulForwardImpl::AlgoPack MatrixMulForwardImpl::sm_algo_pack;

MEGDNN_DEF_GET_ALGO_FROM_DESC(MatrixMulForwardImpl)
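
// SizeArgs bundles the operator and the A/B/C layouts an algorithm needs for
// availability and workspace queries; ExecArgs additionally carries the actual
// tensors and the workspace used at execution time.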
MatrixMulForwardImpl::AlgoBase::SizeArgs::SizeArgs(
        MatrixMulForwardImpl* o, const TensorLayout& A, const TensorLayout& B,
        const TensorLayout& C)
        : opr{o}, layout_a{A}, layout_b{B}, layout_c{C} {}

MatrixMulForwardImpl::AlgoBase::ExecArgs::ExecArgs(
        MatrixMulForwardImpl* opr, _megdnn_tensor_in A, _megdnn_tensor_in B,
        _megdnn_tensor_out C, _megdnn_workspace workspace)
        : SizeArgs(opr, A.layout, B.layout, C.layout),
          tensor_a{A},
          tensor_b{B},
          tensor_c{C},
          workspace{workspace} {}
std::string MatrixMulForwardImpl::AlgoBase::SizeArgs::to_string() const {
    auto&& param = opr->param();
    size_t m = layout_a.shape[0], n = layout_b.shape[1],
           k = layout_a.shape[param.transposeA ? 0 : 1];
    MEGDNN_MARK_USED_VAR(m);
    MEGDNN_MARK_USED_VAR(n);
    MEGDNN_MARK_USED_VAR(k);
    return ssprintf(
            "A={%zux%zu},B={%zux%zu},C={%zux%zu},Transpose A=%d,Transpose "
            "B=%d,ldA=%zu,ldB=%zu,ldC=%zu",
            m, k, k, n, m, n, param.transposeA, param.transposeB,
            layout_a.stride[0], layout_b.stride[0], layout_c.stride[0]);
}

// vim: syntax=cpp.doxygen