
cutlass_convolution_wrapper.cu

/**
 * \file dnn/src/cuda/conv_bias/cutlass_convolution_wrapper.cu
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
// ignore warnings from cutlass headers
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#if !MEGDNN_TEGRA_X1
#include "cutlass/convolution/device/convolution.h"
#endif
#include "src/common/opr_param_defs_enumv.cuh"
#include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh"
#pragma GCC diagnostic pop

using namespace megdnn;
using namespace cuda;
using namespace cutlass_wrapper;
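
// Tegra X1 (Maxwell) lacks the tensor cores these IMMA kernels need, so an
// empty stub is compiled there; all other targets get the real implementation.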
#if MEGDNN_TEGRA_X1
template <bool NeedLoadFromConstMem>
void megdnn::cuda::cutlass_wrapper::
        do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32(
                const int8_t* /* d_src */, const int8_t* /* d_filter */,
                const int32_t* /* d_bias */, const int8_t* /* d_z */,
                int8_t* /* d_dst */, int* /* workspace */,
                const convolution::ConvParam& /* param */,
                uint32_t /* nonlinear_mode */, float /* alpha */,
                float /* beta */, float /* gamma */, float /* scale */,
                const GemmCoord& /* threadblock_shape */,
                const GemmCoord& /* warp_shape */, cudaStream_t /* stream */) {}
#else
template <bool NeedLoadFromConstMem>
void megdnn::cuda::cutlass_wrapper::
        do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32(
                const int8_t* d_src, const int8_t* d_filter,
                const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst,
                int* workspace, const convolution::ConvParam& param,
                uint32_t nonlinear_mode, float alpha, float beta, float gamma,
                float scale, const GemmCoord& threadblock_shape,
                const GemmCoord& warp_shape, cudaStream_t stream) {
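// For a matching (threadblock, warp) tile shape, instantiate a cutlass
// convolution on NCxHWx<32> layouts for Sm75 tensor cores and launch it.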
#define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_,       \
                                        threadblock_k_, warp_m_, warp_n_,     \
                                        warp_k_)                              \
    if (threadblock_shape.m() == threadblock_m_ &&                            \
        threadblock_shape.n() == threadblock_n_ &&                            \
        threadblock_shape.k() == threadblock_k_ &&                            \
        warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ &&             \
        warp_shape.k() == warp_k_) {                                          \
        using ThreadBlockShape =                                              \
                cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_,      \
                                         threadblock_k_>;                     \
        using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>;\
        using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>;          \
        using Convolution = cutlass::convolution::device::Convolution<        \
                int8_t, cutlass::layout::TensorNCxHWx<32>, int8_t,            \
                cutlass::layout::TensorCxRSKx<32>, ElementOutput,             \
                cutlass::layout::TensorNCxHWx<32>, int32_t,                   \
                cutlass::layout::TensorNCxHWx<32>, int32_t,                   \
                cutlass::convolution::ConvType::kConvolution,                 \
                cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75,          \
                ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp,    \
                cutlass::convolution::threadblock::                           \
                        ConvolutionNCxHWxThreadblockSwizzle<                  \
                                cutlass::convolution::ConvType::kConvolution>,\
                2, 16, 16, NeedLoadFromConstMem>;                             \
        typename Convolution::ConvolutionParameter conv_param{                \
                param.n, param.ci, param.co, param.hi, param.wi,              \
                param.fh, param.fw, param.ho, param.wo, param.sh,             \
                param.sw, param.ph, param.pw, 1, 1};                          \
        return cutlass_convolution_wrapper<Convolution>(                      \
                d_src, d_filter, d_bias, d_z, d_dst, workspace, conv_param,   \
                epilogue, stream);                                            \
    }
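// Try each supported tile configuration in turn; every branch returns on a
// match, so reaching the final assert means the requested shapes are unsupported.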
#define DISPATCH_KERNEL                                                       \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(256, 128, 64, 64, 64, 64);                \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 256, 64, 64, 64, 64);                \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 64, 64, 64, 64);                \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 64, 32, 64, 64);                 \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 32, 64);                 \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 64, 64, 32, 32, 64);                  \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 64, 32, 16, 64);                  \
    megdnn_assert(false,                                                      \
                  "unsupported threadblock shape (%dx%dx%d) and warp shape "  \
                  "(%dx%dx%d)",                                               \
                  threadblock_shape.m(), threadblock_shape.n(),               \
                  threadblock_shape.k(), warp_shape.m(), warp_shape.n(),      \
                  warp_shape.k());
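    // Fix the epilogue element types, then pick the bias + activation epilogue
    // for the requested nonlinearity; each case either returns from
    // DISPATCH_KERNEL or trips its trailing assert, so no case falls through.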
    using ElementOutput = int8_t;
    using ElementAccumulator = int32_t;
    using ElementBias = int32_t;
    using ElementCompute = float;
    using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode;
    switch (nonlinear_mode) {
        case NonlineMode::IDENTITY: {
            using EpilogueOp =
                    cutlass::epilogue::thread::BiasAddLinearCombinationClamp<
                            ElementOutput, 8, ElementAccumulator, ElementBias,
                            ElementCompute>;
            typename EpilogueOp::Params epilogue{alpha, beta, gamma};
            DISPATCH_KERNEL;
        }
        case NonlineMode::RELU: {
            using EpilogueOp = cutlass::epilogue::thread::
                    BiasAddLinearCombinationReluClamp<
                            ElementOutput, 8, ElementAccumulator, ElementBias,
                            ElementCompute>;
            typename EpilogueOp::Params epilogue{alpha, beta, gamma, 0};
            DISPATCH_KERNEL;
        }
        case NonlineMode::H_SWISH: {
            using EpilogueOp = cutlass::epilogue::thread::
                    BiasAddLinearCombinationHSwishClamp<
                            ElementOutput, 8, ElementAccumulator, ElementBias,
                            ElementCompute>;
            typename EpilogueOp::Params epilogue{alpha, beta, gamma, scale};
            DISPATCH_KERNEL;
        }
        default:
            megdnn_assert(false,
                          "unsupported nonlinear mode for conv bias operator");
    }
#undef DISPATCH_KERNEL_WITH_TILE_SHAPE
#undef DISPATCH_KERNEL
}
#endif

#define INST(need_load_from_const_mem)                                        \
    template void megdnn::cuda::cutlass_wrapper::                             \
            do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32<                 \
                    need_load_from_const_mem>(                                \
                    const int8_t* d_src, const int8_t* d_filter,              \
                    const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst,  \
                    int* workspace, const convolution::ConvParam& param,      \
                    uint32_t nonlinear_mode, float alpha, float beta,         \
                    float gamma, float scale,                                 \
                    const GemmCoord& threadblock_shape,                       \
                    const GemmCoord& warp_shape, cudaStream_t stream);
INST(true);
INST(false);
#undef INST

// vim: syntax=cuda.doxygen
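
For context, a hypothetical call site might look like the sketch below. The `ConvParam` field names and the supported tile shapes are taken from the file above; the device pointers, the parameter values, and the stream are illustrative assumptions, not MegEngine's actual call path.

// Hypothetical call site (illustrative only): launch the 128x128x64
// threadblock / 64x64x64 warp variant with an identity epilogue.
// d_src, d_filter, d_bias, d_z, d_dst, workspace are assumed to be
// device buffers already laid out for the NCxHWx<32> / CxRSKx<32> kernels.
using namespace megdnn::cuda;

convolution::ConvParam param;
param.n = 1;   param.ci = 64;  param.co = 64;   // batch, input/output channels
param.hi = 56; param.wi = 56;                   // input H, W
param.fh = 3;  param.fw = 3;                    // filter H, W
param.ho = 56; param.wo = 56;                   // output H, W
param.sh = 1;  param.sw = 1;                    // strides
param.ph = 1;  param.pw = 1;                    // paddings

cutlass_wrapper::do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32<false>(
        d_src, d_filter, d_bias, d_z, d_dst, workspace, param,
        static_cast<uint32_t>(
                megdnn::param_enumv::ConvBias::NonlineMode::IDENTITY),
        /* alpha */ 1.f, /* beta */ 1.f, /* gamma */ 0.f, /* scale */ 1.f,
        cutlass_wrapper::GemmCoord{128, 128, 64},  // threadblock shape
        cutlass_wrapper::GemmCoord{64, 64, 64},    // warp shape
        stream);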

The MegEngine installation package already bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU and GPU build to choose between. To run GPU programs, make sure the machine has a GPU installed along with its driver. If you would like to try deep learning development on cloud GPU compute, the MegStudio platform is available.
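
If you need to verify that requirement, a minimal sketch using the standard CUDA runtime API is shown below. Note that the IMMA kernels in the file above target cutlass::arch::Sm75, i.e. compute capability 7.5 (Turing) or a compatible device.

#include <cstdio>
#include <cuda_runtime.h>

// Standalone sanity check: is a CUDA-capable GPU with a working driver present?
int main() {
    int count = 0;
    cudaError_t err = cudaGetDeviceCount(&count);
    if (err != cudaSuccess || count == 0) {
        std::printf("no usable CUDA device: %s\n", cudaGetErrorString(err));
        return 1;
    }
    cudaDeviceProp prop;
    cudaGetDeviceProperties(&prop, 0);
    std::printf("found %d device(s); device 0: %s (sm_%d%d)\n", count,
                prop.name, prop.major, prop.minor);
    return 0;
}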