
cutlass_convolution_wrapper.cu

/**
 * \file dnn/src/cuda/conv_bias/cutlass_convolution_wrapper.cu
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied.
 */
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#if !MEGDNN_TEGRA_X1
#include "cutlass/convolution/device/convolution.h"
#endif
#include "src/common/opr_param_defs_enumv.cuh"
#include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh"
#pragma GCC diagnostic pop

using namespace megdnn;
using namespace cuda;
using namespace cutlass_wrapper;
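
// NOTE: on Tegra X1 builds (MEGDNN_TEGRA_X1) the CUTLASS convolution header
// is not included and every wrapper in this file compiles to an empty stub.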
/* ================= cutlass kernel wrapper for nchw32 layout ================
 */
#if MEGDNN_TEGRA_X1
template <bool NeedLoadFromConstMem>
void megdnn::cuda::cutlass_wrapper::
        do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32(
                const int8_t* /* d_src */, const int8_t* /* d_filter */,
                const int32_t* /* d_bias */, const int8_t* /* d_z */,
                int8_t* /* d_dst */, int* /* workspace */,
                const convolution::ConvParam& /* param */,
                uint32_t /* nonlinear_mode */, float /* alpha */,
                float /* beta */, float /* gamma */, float /* scale */,
                const GemmCoord& /* threadblock_shape */,
                const GemmCoord& /* warp_shape */, cudaStream_t /* stream */) {}
#else
template <bool NeedLoadFromConstMem>
void megdnn::cuda::cutlass_wrapper::
        do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32(
                const int8_t* d_src, const int8_t* d_filter,
                const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst,
                int* workspace, const convolution::ConvParam& param,
                uint32_t nonlinear_mode, float alpha, float beta, float gamma,
                float scale, const GemmCoord& threadblock_shape,
                const GemmCoord& warp_shape, cudaStream_t stream) {
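// DISPATCH_KERNEL_WITH_TILE_SHAPE instantiates one
// cutlass::convolution::device::Convolution for a fixed (threadblock, warp)
// tile shape and returns through cutlass_convolution_wrapper as soon as the
// runtime shapes match; DISPATCH_KERNEL enumerates the supported shapes and
// asserts if none of them matches. The same pattern is repeated in every
// layout section below.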
#define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \
                                        threadblock_k_, warp_m_, warp_n_, \
                                        warp_k_) \
    if (threadblock_shape.m() == threadblock_m_ && \
        threadblock_shape.n() == threadblock_n_ && \
        threadblock_shape.k() == threadblock_k_ && \
        warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \
        warp_shape.k() == warp_k_) { \
        using ThreadBlockShape = \
                cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \
                                         threadblock_k_>; \
        using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \
        using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; \
        using Convolution = cutlass::convolution::device::Convolution< \
                int8_t, cutlass::layout::TensorNCxHWx<32>, int8_t, \
                cutlass::layout::TensorCxRSKx<32>, ElementOutput, \
                cutlass::layout::TensorNCxHWx<32>, int32_t, \
                cutlass::layout::TensorNCxHWx<32>, int32_t, \
                cutlass::convolution::ConvType::kConvolution, \
                cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \
                ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \
                cutlass::convolution::threadblock:: \
                        ConvolutionNCxHWxThreadblockSwizzle< \
                                cutlass::convolution::ConvType::kConvolution>, \
                2, 16, 16, NeedLoadFromConstMem>; \
        typename Convolution::ConvolutionParameter conv_param{ \
                param.n, param.ci, param.co, param.hi, param.wi, \
                param.fh, param.fw, param.ho, param.wo, param.sh, \
                param.sw, param.ph, param.pw, 1, 1}; \
        return cutlass_convolution_wrapper<Convolution>( \
                d_src, d_filter, d_bias, d_z, d_dst, workspace, conv_param, \
                epilogue, stream); \
    }
#define DISPATCH_KERNEL \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(256, 128, 64, 64, 64, 64); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 256, 64, 64, 64, 64); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 64, 64, 64, 64); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 64, 32, 64, 64); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 32, 64); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 64, 64, 32, 32, 64); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 64, 32, 16, 64); \
    megdnn_assert(false, \
                  "unsupported threadblock shape (%dx%dx%d) and warp shape " \
                  "(%dx%dx%d)", \
                  threadblock_shape.m(), threadblock_shape.n(), \
                  threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \
                  warp_shape.k());
    using ElementOutput = int8_t;
    using ElementAccumulator = int32_t;
    using ElementBias = int32_t;
    using ElementCompute = float;
    using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode;
    switch (nonlinear_mode) {
        case NonlineMode::IDENTITY: {
            using EpilogueOp =
                    cutlass::epilogue::thread::BiasAddLinearCombinationClamp<
                            ElementOutput, 8, ElementAccumulator, ElementBias,
                            ElementCompute>;
            typename EpilogueOp::Params epilogue{alpha, beta, gamma};
            DISPATCH_KERNEL;
        }
        case NonlineMode::RELU: {
            using EpilogueOp = cutlass::epilogue::thread::
                    BiasAddLinearCombinationReluClamp<
                            ElementOutput, 8, ElementAccumulator, ElementBias,
                            ElementCompute>;
            typename EpilogueOp::Params epilogue{alpha, beta, gamma, 0};
            DISPATCH_KERNEL;
        }
        case NonlineMode::H_SWISH: {
            using EpilogueOp = cutlass::epilogue::thread::
                    BiasAddLinearCombinationHSwishClamp<
                            ElementOutput, 8, ElementAccumulator, ElementBias,
                            ElementCompute>;
            typename EpilogueOp::Params epilogue{alpha, beta, gamma, scale};
            DISPATCH_KERNEL;
        }
        default:
            megdnn_assert(false,
                          "unsupported nonlinear mode for conv bias operator");
    }
#undef DISPATCH_KERNEL_WITH_TILE_SHAPE
#undef DISPATCH_KERNEL
}
#endif

#define INST(need_load_from_const_mem) \
    template void megdnn::cuda::cutlass_wrapper:: \
            do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32< \
                    need_load_from_const_mem>( \
                    const int8_t* d_src, const int8_t* d_filter, \
                    const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, \
                    int* workspace, const convolution::ConvParam& param, \
                    uint32_t nonlinear_mode, float alpha, float beta, \
                    float gamma, float scale, \
                    const GemmCoord& threadblock_shape, \
                    const GemmCoord& warp_shape, cudaStream_t stream);
INST(true);
INST(false);
#undef INST

/* ==== cutlass kernel wrapper for nchw32 layout and nchw4 output ===== */
#if MEGDNN_TEGRA_X1
template <bool NeedLoadFromConstMem>
void megdnn::cuda::cutlass_wrapper::
        do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4(
                const int8_t* /* d_src */, const int8_t* /* d_filter */,
                const int32_t* /* d_bias */, const int8_t* /* d_z */,
                int8_t* /* d_dst */, int* /* workspace */,
                const convolution::ConvParam& /* param */,
                uint32_t /* nonlinear_mode */, float /* alpha */,
                float /* beta */, float /* gamma */, float /* scale */,
                const GemmCoord& /* threadblock_shape */,
                const GemmCoord& /* warp_shape */, cudaStream_t /* stream */) {}
#else
template <bool NeedLoadFromConstMem>
void megdnn::cuda::cutlass_wrapper::
        do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4(
                const int8_t* d_src, const int8_t* d_filter,
                const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst,
                int* workspace, const convolution::ConvParam& param,
                uint32_t nonlinear_mode, float alpha, float beta, float gamma,
                float scale, const GemmCoord& threadblock_shape,
                const GemmCoord& warp_shape, cudaStream_t stream) {
#define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \
                                        threadblock_k_, warp_m_, warp_n_, \
                                        warp_k_) \
    if (threadblock_shape.m() == threadblock_m_ && \
        threadblock_shape.n() == threadblock_n_ && \
        threadblock_shape.k() == threadblock_k_ && \
        warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \
        warp_shape.k() == warp_k_) { \
        using ThreadBlockShape = \
                cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \
                                         threadblock_k_>; \
        using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \
        using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; \
        using Convolution = cutlass::convolution::device::Convolution< \
                int8_t, cutlass::layout::TensorNCxHWx<32>, int8_t, \
                cutlass::layout::TensorCxRSKx<32>, ElementOutput, \
                cutlass::layout::TensorNCxHWx<4>, int32_t, \
                cutlass::layout::TensorNCxHWx<4>, int32_t, \
                cutlass::convolution::ConvType::kConvolution, \
                cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \
                ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \
                cutlass::convolution::threadblock:: \
                        ConvolutionNCxHWxThreadblockSwizzle< \
                                cutlass::convolution::ConvType::kConvolution>, \
                2, 16, 16, NeedLoadFromConstMem>; \
        typename Convolution::ConvolutionParameter conv_param{ \
                param.n, param.ci, param.co, param.hi, param.wi, \
                param.fh, param.fw, param.ho, param.wo, param.sh, \
                param.sw, param.ph, param.pw, 1, 1}; \
        return cutlass_convolution_wrapper<Convolution>( \
                d_src, d_filter, d_bias, d_z, d_dst, workspace, conv_param, \
                epilogue, stream); \
    }
#define DISPATCH_KERNEL \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(256, 128, 64, 64, 64, 64); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 256, 64, 64, 64, 64); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 64, 64, 64, 64); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 64, 32, 64, 64); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 32, 64); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 64, 64, 32, 32, 64); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 64, 16, 32, 64); \
    megdnn_assert(false, \
                  "unsupported threadblock shape (%dx%dx%d) and warp shape " \
                  "(%dx%dx%d)", \
                  threadblock_shape.m(), threadblock_shape.n(), \
                  threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \
                  warp_shape.k());
    using ElementOutput = int8_t;
    using ElementAccumulator = int32_t;
    using ElementBias = int32_t;
    using ElementCompute = float;
    using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode;
    switch (nonlinear_mode) {
        case NonlineMode::IDENTITY: {
            using EpilogueOp =
                    cutlass::epilogue::thread::BiasAddLinearCombinationClamp<
                            ElementOutput, 4, ElementAccumulator, ElementBias,
                            ElementCompute>;
            typename EpilogueOp::Params epilogue{alpha, beta, gamma};
            DISPATCH_KERNEL;
        }
        case NonlineMode::RELU: {
            using EpilogueOp = cutlass::epilogue::thread::
                    BiasAddLinearCombinationReluClamp<
                            ElementOutput, 4, ElementAccumulator, ElementBias,
                            ElementCompute>;
            typename EpilogueOp::Params epilogue{alpha, beta, gamma, 0};
            DISPATCH_KERNEL;
        }
        case NonlineMode::H_SWISH: {
            using EpilogueOp = cutlass::epilogue::thread::
                    BiasAddLinearCombinationHSwishClamp<
                            ElementOutput, 4, ElementAccumulator, ElementBias,
                            ElementCompute>;
            typename EpilogueOp::Params epilogue{alpha, beta, gamma, scale};
            DISPATCH_KERNEL;
        }
        default:
            megdnn_assert(false,
                          "unsupported nonlinear mode for conv bias operator");
    }
#undef DISPATCH_KERNEL_WITH_TILE_SHAPE
#undef DISPATCH_KERNEL
}
#endif

#define INST(need_load_from_const_mem) \
    template void megdnn::cuda::cutlass_wrapper:: \
            do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4< \
                    need_load_from_const_mem>( \
                    const int8_t* d_src, const int8_t* d_filter, \
                    const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, \
                    int* workspace, const convolution::ConvParam& param, \
                    uint32_t nonlinear_mode, float alpha, float beta, \
                    float gamma, float scale, \
                    const GemmCoord& threadblock_shape, \
                    const GemmCoord& warp_shape, cudaStream_t stream);
INST(true);
INST(false);
#undef INST

/* ================ cutlass kernel wrapper for nchw4 layout ================= */
#if MEGDNN_TEGRA_X1
template <bool NeedLoadFromConstMem>
void megdnn::cuda::cutlass_wrapper::
        do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4(
                const int8_t* /* d_src */, const int8_t* /* d_filter */,
                const int32_t* /* d_bias */, const int8_t* /* d_z */,
                int8_t* /* d_dst */, int* /* workspace */,
                const convolution::ConvParam& /* param */,
                uint32_t /* nonlinear_mode */, float /* alpha */,
                float /* beta */, float /* gamma */, float /* scale */,
                const GemmCoord& /* threadblock_shape */,
                const GemmCoord& /* warp_shape */, cudaStream_t /* stream */) {}
#else
template <bool NeedLoadFromConstMem>
void megdnn::cuda::cutlass_wrapper::
        do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4(
                const int8_t* d_src, const int8_t* d_filter,
                const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst,
                int* workspace, const convolution::ConvParam& param,
                uint32_t nonlinear_mode, float alpha, float beta, float gamma,
                float scale, const GemmCoord& threadblock_shape,
                const GemmCoord& warp_shape, cudaStream_t stream) {
#define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \
                                        threadblock_k_, warp_m_, warp_n_, \
                                        warp_k_, stage_, aligned_) \
    if (threadblock_shape.m() == threadblock_m_ && \
        threadblock_shape.n() == threadblock_n_ && \
        threadblock_shape.k() == threadblock_k_ && \
        warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \
        warp_shape.k() == warp_k_) { \
        using ThreadBlockShape = \
                cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \
                                         threadblock_k_>; \
        using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \
        using InstructionShape = cutlass::gemm::GemmShape<1, 1, 4>; \
        using Convolution = cutlass::convolution::device::Convolution< \
                int8_t, cutlass::layout::TensorNCxHWx<4>, int8_t, \
                cutlass::layout::TensorCxRSKx<4>, ElementOutput, \
                cutlass::layout::TensorNCxHWx<4>, int32_t, \
                cutlass::layout::TensorNCxHWx<4>, int32_t, \
                cutlass::convolution::ConvType::kConvolution, \
                cutlass::arch::OpClassSimt, cutlass::arch::Sm61, \
                ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \
                cutlass::convolution::threadblock:: \
                        ConvolutionNCxHWxThreadblockSwizzle< \
                                cutlass::convolution::ConvType::kConvolution>, \
                stage_, 4, aligned_, NeedLoadFromConstMem>; \
        typename Convolution::ConvolutionParameter conv_param{ \
                param.n, param.ci, param.co, param.hi, param.wi, \
                param.fh, param.fw, param.ho, param.wo, param.sh, \
                param.sw, param.ph, param.pw, 1, 1}; \
        return cutlass_convolution_wrapper<Convolution>( \
                d_src, d_filter, d_bias, d_z, d_dst, workspace, conv_param, \
                epilogue, stream); \
    }
#define DISPATCH_KERNEL \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 32, 64, 32, 32, 2, 16); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 32, 64, 32, 32, 2, 16); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 64, 32, 32, 2, 16); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 32, 64, 32, 32, 2, 16); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 2, 16); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 64, 32, 64, 32, 32, 2, 16); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 32, 32, 64, 32, 2, 16); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 32, 32, 64, 32, 32, 2, 16); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 32, 32, 32, 32, 32, 2, 16); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 128, 16, 16, 128, 16, 1, 8); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 64, 8, 16, 64, 8, 2, 4); \
    megdnn_assert(false, \
                  "unsupported threadblock shape (%dx%dx%d) and warp shape " \
                  "(%dx%dx%d)", \
                  threadblock_shape.m(), threadblock_shape.n(), \
                  threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \
                  warp_shape.k());
    using ElementOutput = int8_t;
    using ElementAccumulator = int32_t;
    using ElementBias = int32_t;
    using ElementCompute = float;
    using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode;
    switch (nonlinear_mode) {
        case NonlineMode::IDENTITY: {
            using EpilogueOp =
                    cutlass::epilogue::thread::BiasAddLinearCombinationClamp<
                            ElementOutput, 4, ElementAccumulator, ElementBias,
                            ElementCompute>;
            typename EpilogueOp::Params epilogue{alpha, beta, gamma};
            DISPATCH_KERNEL;
        }
        case NonlineMode::RELU: {
            using EpilogueOp = cutlass::epilogue::thread::
                    BiasAddLinearCombinationReluClamp<
                            ElementOutput, 4, ElementAccumulator, ElementBias,
                            ElementCompute>;
            typename EpilogueOp::Params epilogue{alpha, beta, gamma, 0};
            DISPATCH_KERNEL;
        }
        case NonlineMode::H_SWISH: {
            using EpilogueOp = cutlass::epilogue::thread::
                    BiasAddLinearCombinationHSwishClamp<
                            ElementOutput, 4, ElementAccumulator, ElementBias,
                            ElementCompute>;
            typename EpilogueOp::Params epilogue{alpha, beta, gamma, scale};
            DISPATCH_KERNEL;
        }
        default:
            megdnn_assert(false,
                          "unsupported nonlinear mode for conv bias operator");
    }
#undef DISPATCH_KERNEL_WITH_TILE_SHAPE
#undef DISPATCH_KERNEL
}
#endif

#define INST(need_load_from_const_mem) \
    template void megdnn::cuda::cutlass_wrapper:: \
            do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4< \
                    need_load_from_const_mem>( \
                    const int8_t* d_src, const int8_t* d_filter, \
                    const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, \
                    int* workspace, const convolution::ConvParam& param, \
                    uint32_t nonlinear_mode, float alpha, float beta, \
                    float gamma, float scale, \
                    const GemmCoord& threadblock_shape, \
                    const GemmCoord& warp_shape, cudaStream_t stream);
INST(true);
INST(false);
#undef INST

/* ===== cutlass kernel wrapper for nchw4 layout and nchw output ===== */
#if MEGDNN_TEGRA_X1
template <bool NeedLoadFromConstMem>
void megdnn::cuda::cutlass_wrapper::
        do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nchw(
                const int8_t* /* d_src */, const int8_t* /* d_filter */,
                const float* /* d_bias */, const float* /* d_z */,
                float* /* d_dst */, int* /* workspace */,
                const convolution::ConvParam& /* param */,
                uint32_t /* nonlinear_mode */, float /* alpha */,
                float /* beta */, float /* gamma */, float /* scale */,
                const GemmCoord& /* threadblock_shape */,
                const GemmCoord& /* warp_shape */, cudaStream_t /* stream */) {}
#else
template <bool NeedLoadFromConstMem>
void megdnn::cuda::cutlass_wrapper::
        do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nchw(
                const int8_t* d_src, const int8_t* d_filter,
                const float* d_bias, const float* d_z, float* d_dst,
                int* workspace, const convolution::ConvParam& param,
                uint32_t nonlinear_mode, float alpha, float beta, float gamma,
                float scale, const GemmCoord& threadblock_shape,
                const GemmCoord& warp_shape, cudaStream_t stream) {
#define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \
                                        threadblock_k_, warp_m_, warp_n_, \
                                        warp_k_, aligned_) \
    if (threadblock_shape.m() == threadblock_m_ && \
        threadblock_shape.n() == threadblock_n_ && \
        threadblock_shape.k() == threadblock_k_ && \
        warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \
        warp_shape.k() == warp_k_) { \
        using ThreadBlockShape = \
                cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \
                                         threadblock_k_>; \
        using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \
        using InstructionShape = cutlass::gemm::GemmShape<1, 1, 4>; \
        using Convolution = cutlass::convolution::device::Convolution< \
                int8_t, cutlass::layout::TensorNCxHWx<4>, int8_t, \
                cutlass::layout::TensorCxRSKx<4>, ElementOutput, \
                cutlass::layout::TensorNCHW, float, \
                cutlass::layout::TensorNCHW, int32_t, \
                cutlass::convolution::ConvType::kConvolution, \
                cutlass::arch::OpClassSimt, cutlass::arch::Sm61, \
                ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \
                cutlass::convolution::threadblock:: \
                        ConvolutionNCxHWxThreadblockSwizzle< \
                                cutlass::convolution::ConvType::kConvolution>, \
                2, 4, aligned_, NeedLoadFromConstMem, \
                cutlass::arch::OpMultiplyAdd>; \
        typename Convolution::ConvolutionParameter conv_param{ \
                param.n, param.ci, param.co, param.hi, param.wi, \
                param.fh, param.fw, param.ho, param.wo, param.sh, \
                param.sw, param.ph, param.pw, 1, 1}; \
        return cutlass_convolution_wrapper<Convolution>( \
                d_src, d_filter, d_bias, d_z, d_dst, workspace, conv_param, \
                epilogue, stream); \
    }
#define DISPATCH_KERNEL \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 32, 64, 32, 32, 16); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 32, 64, 32, 32, 16); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 64, 32, 32, 16); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 32, 64, 32, 32, 16); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 16); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 64, 32, 64, 32, 32, 16); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 32, 32, 64, 32, 16); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 32, 32, 64, 32, 32, 16); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 32, 32, 32, 32, 32, 16); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 64, 8, 16, 64, 8, 4); \
    megdnn_assert(false, \
                  "unsupported threadblock shape (%dx%dx%d) and warp shape " \
                  "(%dx%dx%d)", \
                  threadblock_shape.m(), threadblock_shape.n(), \
                  threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \
                  warp_shape.k());
    using ElementOutput = float;
    using ElementAccumulator = int32_t;
    using ElementBias = float;
    using ElementCompute = float;
    using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode;
    switch (nonlinear_mode) {
        case NonlineMode::IDENTITY: {
            using EpilogueOp =
                    cutlass::epilogue::thread::BiasAddLinearCombination<
                            ElementOutput, 1, ElementAccumulator, ElementBias,
                            ElementCompute>;
            typename EpilogueOp::Params epilogue{alpha, beta, gamma};
            DISPATCH_KERNEL;
        }
        case NonlineMode::RELU: {
            using EpilogueOp =
                    cutlass::epilogue::thread::BiasAddLinearCombinationRelu<
                            ElementOutput, 1, ElementAccumulator, ElementBias,
                            ElementCompute>;
            typename EpilogueOp::Params epilogue{alpha, beta, gamma, 0};
            DISPATCH_KERNEL;
        }
        case NonlineMode::H_SWISH: {
            using EpilogueOp =
                    cutlass::epilogue::thread::BiasAddLinearCombinationHSwish<
                            ElementOutput, 1, ElementAccumulator, ElementBias,
                            ElementCompute>;
            typename EpilogueOp::Params epilogue{alpha, beta, gamma, scale};
            DISPATCH_KERNEL;
        }
        default:
            megdnn_assert(false,
                          "unsupported nonlinear mode for conv bias operator");
    }
#undef DISPATCH_KERNEL_WITH_TILE_SHAPE
#undef DISPATCH_KERNEL
}
#endif

#define INST(need_load_from_const_mem) \
    template void megdnn::cuda::cutlass_wrapper:: \
            do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nchw< \
                    need_load_from_const_mem>( \
                    const int8_t* d_src, const int8_t* d_filter, \
                    const float* d_bias, const float* d_z, float* d_dst, \
                    int* workspace, const convolution::ConvParam& param, \
                    uint32_t nonlinear_mode, float alpha, float beta, \
                    float gamma, float scale, \
                    const GemmCoord& threadblock_shape, \
                    const GemmCoord& warp_shape, cudaStream_t stream);
INST(true);
INST(false);
#undef INST

/* ====== cutlass kernel wrapper for nchw4 layout and nchw32 output ====== */
#if MEGDNN_TEGRA_X1
template <bool NeedLoadFromConstMem>
void megdnn::cuda::cutlass_wrapper::
        do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32(
                const int8_t* /* d_src */, const int8_t* /* d_filter */,
                const int32_t* /* d_bias */, const int8_t* /* d_z */,
                int8_t* /* d_dst */, int* /* workspace */,
                const convolution::ConvParam& /* param */,
                uint32_t /* nonlinear_mode */, float /* alpha */,
                float /* beta */, float /* gamma */, float /* scale */,
                const GemmCoord& /* threadblock_shape */,
                const GemmCoord& /* warp_shape */, cudaStream_t /* stream */) {}
#else
template <bool NeedLoadFromConstMem>
void megdnn::cuda::cutlass_wrapper::
        do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32(
                const int8_t* d_src, const int8_t* d_filter,
                const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst,
                int* workspace, const convolution::ConvParam& param,
                uint32_t nonlinear_mode, float alpha, float beta, float gamma,
                float scale, const GemmCoord& threadblock_shape,
                const GemmCoord& warp_shape, cudaStream_t stream) {
#define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \
                                        threadblock_k_, warp_m_, warp_n_, \
                                        warp_k_, aligned_) \
    if (threadblock_shape.m() == threadblock_m_ && \
        threadblock_shape.n() == threadblock_n_ && \
        threadblock_shape.k() == threadblock_k_ && \
        warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \
        warp_shape.k() == warp_k_) { \
        using ThreadBlockShape = \
                cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \
                                         threadblock_k_>; \
        using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \
        using InstructionShape = cutlass::gemm::GemmShape<1, 1, 4>; \
        using Convolution = cutlass::convolution::device::Convolution< \
                int8_t, cutlass::layout::TensorNCxHWx<4>, int8_t, \
                cutlass::layout::TensorCxRSKx<4>, ElementOutput, \
                cutlass::layout::TensorNCxHWx<32>, int32_t, \
                cutlass::layout::TensorNCxHWx<32>, int32_t, \
                cutlass::convolution::ConvType::kConvolution, \
                cutlass::arch::OpClassSimt, cutlass::arch::Sm61, \
                ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \
                cutlass::convolution::threadblock:: \
                        ConvolutionNCxHWxThreadblockSwizzle< \
                                cutlass::convolution::ConvType::kConvolution>, \
                2, 4, aligned_, NeedLoadFromConstMem>; \
        typename Convolution::ConvolutionParameter conv_param{ \
                param.n, param.ci, param.co, param.hi, param.wi, \
                param.fh, param.fw, param.ho, param.wo, param.sh, \
                param.sw, param.ph, param.pw, 1, 1}; \
        return cutlass_convolution_wrapper<Convolution>( \
                d_src, d_filter, d_bias, d_z, d_dst, workspace, conv_param, \
                epilogue, stream); \
    }
#define DISPATCH_KERNEL \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 32, 64, 32, 32, 16); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 32, 64, 32, 32, 16); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 64, 32, 32, 16); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 32, 64, 32, 32, 16); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 16); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 64, 32, 64, 32, 32, 16); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 32, 32, 64, 32, 16); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 32, 32, 64, 32, 32, 16); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 32, 32, 32, 32, 32, 16); \
    megdnn_assert(false, \
                  "unsupported threadblock shape (%dx%dx%d) and warp shape " \
                  "(%dx%dx%d)", \
                  threadblock_shape.m(), threadblock_shape.n(), \
                  threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \
                  warp_shape.k());
    using ElementOutput = int8_t;
    using ElementAccumulator = int32_t;
    using ElementBias = int32_t;
    using ElementCompute = float;
    using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode;
    switch (nonlinear_mode) {
        case NonlineMode::IDENTITY: {
            using EpilogueOp =
                    cutlass::epilogue::thread::BiasAddLinearCombinationClamp<
                            ElementOutput, 4, ElementAccumulator, ElementBias,
                            ElementCompute>;
            typename EpilogueOp::Params epilogue{alpha, beta, gamma};
            DISPATCH_KERNEL;
        }
        case NonlineMode::RELU: {
            using EpilogueOp = cutlass::epilogue::thread::
                    BiasAddLinearCombinationReluClamp<
                            ElementOutput, 4, ElementAccumulator, ElementBias,
                            ElementCompute>;
            typename EpilogueOp::Params epilogue{alpha, beta, gamma, 0};
            DISPATCH_KERNEL;
        }
        case NonlineMode::H_SWISH: {
            using EpilogueOp = cutlass::epilogue::thread::
                    BiasAddLinearCombinationHSwishClamp<
                            ElementOutput, 4, ElementAccumulator, ElementBias,
                            ElementCompute>;
            typename EpilogueOp::Params epilogue{alpha, beta, gamma, scale};
            DISPATCH_KERNEL;
        }
        default:
            megdnn_assert(false,
                          "unsupported nonlinear mode for conv bias operator");
    }
#undef DISPATCH_KERNEL_WITH_TILE_SHAPE
#undef DISPATCH_KERNEL
}
#endif

#define INST(need_load_from_const_mem) \
    template void megdnn::cuda::cutlass_wrapper:: \
            do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32< \
                    need_load_from_const_mem>( \
                    const int8_t* d_src, const int8_t* d_filter, \
                    const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, \
                    int* workspace, const convolution::ConvParam& param, \
                    uint32_t nonlinear_mode, float alpha, float beta, \
                    float gamma, float scale, \
                    const GemmCoord& threadblock_shape, \
                    const GemmCoord& warp_shape, cudaStream_t stream);
INST(true);
INST(false);
#undef INST

// vim: syntax=cuda.doxygen
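
For reference, below is a minimal caller-side sketch of how one of the wrappers above might be invoked. It is illustrative only: it assumes that convolution::ConvParam is default-constructible with the same fields the conv_param initializers above read, that GemmCoord provides the (m, n, k) constructor of cutlass::gemm::GemmCoord, and that the device buffers, workspace, and stream have been prepared elsewhere; example_dispatch and all numeric values are hypothetical.

// Illustrative caller-side sketch; ConvParam field assignment and the
// GemmCoord (m, n, k) constructor are assumptions, see the note above.
void example_dispatch(const int8_t* d_src, const int8_t* d_filter,
                      const int32_t* d_bias, int8_t* d_dst, int* workspace,
                      cudaStream_t stream) {
    using namespace megdnn;
    using namespace megdnn::cuda;
    convolution::ConvParam param;
    param.n = 16;                  // batch size
    param.ci = 64; param.co = 64;  // input / output channels
    param.hi = 56; param.wi = 56;  // input spatial size
    param.fh = 3; param.fw = 3;    // filter size
    param.ho = 56; param.wo = 56;  // output spatial size
    param.sh = 1; param.sw = 1;    // stride
    param.ph = 1; param.pw = 1;    // padding
    // int8 dp4a kernel for the NCHW4 layout with an IDENTITY epilogue; the
    // 128x128x32 threadblock / 64x32x32 warp tile is one of the shapes
    // enumerated in the corresponding DISPATCH_KERNEL list.
    cutlass_wrapper::do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4<true>(
            d_src, d_filter, d_bias, /* d_z */ nullptr, d_dst, workspace,
            param,
            static_cast<uint32_t>(param_enumv::ConvBias::NonlineMode::IDENTITY),
            /* alpha */ 1.f, /* beta */ 1.f, /* gamma */ 0.f, /* scale */ 1.f,
            cutlass_wrapper::GemmCoord{128, 128, 32},
            cutlass_wrapper::GemmCoord{64, 32, 32}, stream);
}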
