
cutlass_convolution_wrapper.cu 62 kB

/**
 * \file dnn/src/cuda/conv_bias/cutlass_convolution_wrapper.cu
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied.
 */
// ignore warnings from cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#if !MEGDNN_TEGRA_X1
#include "cutlass/convolution/device/convolution.h"
#endif
#include "src/common/opr_param_defs_enumv.cuh"
#include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh"
#pragma GCC diagnostic pop

using namespace megdnn;
using namespace cuda;
using namespace cutlass_wrapper;

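// Each wrapper below follows the same pattern: DISPATCH_KERNEL_WITH_TILE_SHAPE
// instantiates one cutlass::conv::device::Convolution specialization for a
// particular (threadblock shape, warp shape) tile configuration and runs it
// when the runtime shapes match, DISPATCH_KERNEL enumerates the supported tile
// configurations (asserting if none matches), and the switch on nonlinear_mode
// selects the epilogue that fuses bias, residual (d_z) add and the activation.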
/* ====== cutlass kernel wrapper for int8 nchw32 layout ====== */
#if MEGDNN_TEGRA_X1
template <bool NeedLoadFromConstMem>
void megdnn::cuda::cutlass_wrapper::
        do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32(
                const int8_t* /* d_src */, const int8_t* /* d_filter */,
                const int32_t* /* d_bias */, const int8_t* /* d_z */,
                int8_t* /* d_dst */, int* /* workspace */,
                const convolution::ConvParam& /* param */,
                uint32_t /* nonlinear_mode */, float /* alpha */,
                float /* beta */, float /* gamma */, float /* scale */,
                const GemmCoord& /* threadblock_shape */,
                const GemmCoord& /* warp_shape */, cudaStream_t /* stream */) {}
#else
template <bool NeedLoadFromConstMem>
void megdnn::cuda::cutlass_wrapper::
        do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32(
                const int8_t* d_src, const int8_t* d_filter,
                const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst,
                int* workspace, const convolution::ConvParam& param,
                uint32_t nonlinear_mode, float alpha, float beta, float gamma,
                float scale, const GemmCoord& threadblock_shape,
                const GemmCoord& warp_shape, cudaStream_t stream) {
#define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \
                                        threadblock_k_, warp_m_, warp_n_, \
                                        warp_k_) \
    if (threadblock_shape.m() == threadblock_m_ && \
        threadblock_shape.n() == threadblock_n_ && \
        threadblock_shape.k() == threadblock_k_ && \
        warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \
        warp_shape.k() == warp_k_) { \
        using ThreadBlockShape = \
                cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \
                                         threadblock_k_>; \
        using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \
        using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; \
        using Convolution = cutlass::conv::device::Convolution< \
                int8_t, cutlass::layout::TensorNCxHWx<32>, int8_t, \
                cutlass::layout::TensorCxRSKx<32>, ElementOutput, \
                cutlass::layout::TensorNCxHWx<32>, int32_t, \
                cutlass::layout::TensorNCxHWx<32>, int32_t, \
                cutlass::conv::ConvType::kConvolution, \
                cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \
                ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \
                cutlass::conv::threadblock:: \
                        ConvolutionFpropNCxHWxThreadblockSwizzle, \
                2, 16, 16, NeedLoadFromConstMem>; \
        typename Convolution::ConvolutionParameter conv_param( \
                param.n, param.hi, param.wi, param.ci, param.co, param.fh, \
                param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \
                param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \
        return cutlass_convolution_wrapper<Convolution>( \
                d_src, d_filter, d_bias, d_z, d_dst, workspace, conv_param, \
                epilogue, stream); \
    }
#define DISPATCH_KERNEL \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(256, 128, 64, 64, 64, 64); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 256, 64, 64, 64, 64); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 64, 64, 64, 64); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 64, 32, 64, 64); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 32, 64); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 64, 64, 32, 32, 64); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 64, 32, 16, 64); \
    megdnn_assert(false, \
                  "unsupported threadblock shape (%dx%dx%d) and warp shape " \
                  "(%dx%dx%d)", \
                  threadblock_shape.m(), threadblock_shape.n(), \
                  threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \
                  warp_shape.k());
    using ElementOutput = int8_t;
    using ElementAccumulator = int32_t;
    using ElementBias = int32_t;
    using ElementCompute = float;
    using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode;
    switch (nonlinear_mode) {
        case NonlineMode::IDENTITY: {
            using EpilogueOp =
                    cutlass::epilogue::thread::BiasAddLinearCombinationClamp<
                            ElementOutput, 8, ElementAccumulator, ElementBias,
                            ElementCompute>;
            typename EpilogueOp::Params epilogue{alpha, beta, gamma};
            DISPATCH_KERNEL;
        }
        case NonlineMode::RELU: {
            using EpilogueOp = cutlass::epilogue::thread::
                    BiasAddLinearCombinationReluClamp<
                            ElementOutput, 8, ElementAccumulator, ElementBias,
                            ElementCompute>;
            typename EpilogueOp::Params epilogue{alpha, beta, gamma, 0};
            DISPATCH_KERNEL;
        }
        case NonlineMode::H_SWISH: {
            using EpilogueOp = cutlass::epilogue::thread::
                    BiasAddLinearCombinationHSwishClamp<
                            ElementOutput, 8, ElementAccumulator, ElementBias,
                            ElementCompute>;
            typename EpilogueOp::Params epilogue{alpha, beta, gamma, scale};
            DISPATCH_KERNEL;
        }
        default:
            megdnn_assert(false,
                          "unsupported nonlinear mode for conv bias operator");
    }
#undef DISPATCH_KERNEL_WITH_TILE_SHAPE
#undef DISPATCH_KERNEL
}
#endif
#define INST(need_load_from_const_mem) \
    template void megdnn::cuda::cutlass_wrapper:: \
            do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32< \
                    need_load_from_const_mem>( \
                    const int8_t* d_src, const int8_t* d_filter, \
                    const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, \
                    int* workspace, const convolution::ConvParam& param, \
                    uint32_t nonlinear_mode, float alpha, float beta, \
                    float gamma, float scale, \
                    const GemmCoord& threadblock_shape, \
                    const GemmCoord& warp_shape, cudaStream_t stream);
INST(true);
INST(false);
#undef INST

/* ===== cutlass kernel wrapper for int8 nchw32 layout and nchw4 output ===== */
#if MEGDNN_TEGRA_X1
template <bool NeedLoadFromConstMem>
void megdnn::cuda::cutlass_wrapper::
        do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4(
                const int8_t* /* d_src */, const int8_t* /* d_filter */,
                const int32_t* /* d_bias */, const int8_t* /* d_z */,
                int8_t* /* d_dst */, int* /* workspace */,
                const convolution::ConvParam& /* param */,
                uint32_t /* nonlinear_mode */, float /* alpha */,
                float /* beta */, float /* gamma */, float /* scale */,
                const GemmCoord& /* threadblock_shape */,
                const GemmCoord& /* warp_shape */, cudaStream_t /* stream */) {}
#else
template <bool NeedLoadFromConstMem>
void megdnn::cuda::cutlass_wrapper::
        do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4(
                const int8_t* d_src, const int8_t* d_filter,
                const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst,
                int* workspace, const convolution::ConvParam& param,
                uint32_t nonlinear_mode, float alpha, float beta, float gamma,
                float scale, const GemmCoord& threadblock_shape,
                const GemmCoord& warp_shape, cudaStream_t stream) {
#define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \
                                        threadblock_k_, warp_m_, warp_n_, \
                                        warp_k_) \
    if (threadblock_shape.m() == threadblock_m_ && \
        threadblock_shape.n() == threadblock_n_ && \
        threadblock_shape.k() == threadblock_k_ && \
        warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \
        warp_shape.k() == warp_k_) { \
        using ThreadBlockShape = \
                cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \
                                         threadblock_k_>; \
        using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \
        using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; \
        using Convolution = cutlass::conv::device::Convolution< \
                int8_t, cutlass::layout::TensorNCxHWx<32>, int8_t, \
                cutlass::layout::TensorCxRSKx<32>, ElementOutput, \
                cutlass::layout::TensorNCxHWx<4>, int32_t, \
                cutlass::layout::TensorNCxHWx<4>, int32_t, \
                cutlass::conv::ConvType::kConvolution, \
                cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \
                ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \
                cutlass::conv::threadblock:: \
                        ConvolutionFpropNCxHWxThreadblockSwizzle, \
                2, 16, 16, NeedLoadFromConstMem>; \
        typename Convolution::ConvolutionParameter conv_param( \
                param.n, param.hi, param.wi, param.ci, param.co, param.fh, \
                param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \
                param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \
        return cutlass_convolution_wrapper<Convolution>( \
                d_src, d_filter, d_bias, d_z, d_dst, workspace, conv_param, \
                epilogue, stream); \
    }
#define DISPATCH_KERNEL \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(256, 128, 64, 64, 64, 64); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 256, 64, 64, 64, 64); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 64, 64, 64, 64); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 64, 32, 64, 64); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 32, 64); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 64, 64, 32, 32, 64); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 64, 16, 32, 64); \
    megdnn_assert(false, \
                  "unsupported threadblock shape (%dx%dx%d) and warp shape " \
                  "(%dx%dx%d)", \
                  threadblock_shape.m(), threadblock_shape.n(), \
                  threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \
                  warp_shape.k());
    using ElementOutput = int8_t;
    using ElementAccumulator = int32_t;
    using ElementBias = int32_t;
    using ElementCompute = float;
    using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode;
    switch (nonlinear_mode) {
        case NonlineMode::IDENTITY: {
            using EpilogueOp =
                    cutlass::epilogue::thread::BiasAddLinearCombinationClamp<
                            ElementOutput, 4, ElementAccumulator, ElementBias,
                            ElementCompute>;
            typename EpilogueOp::Params epilogue{alpha, beta, gamma};
            DISPATCH_KERNEL;
        }
        case NonlineMode::RELU: {
            using EpilogueOp = cutlass::epilogue::thread::
                    BiasAddLinearCombinationReluClamp<
                            ElementOutput, 4, ElementAccumulator, ElementBias,
                            ElementCompute>;
            typename EpilogueOp::Params epilogue{alpha, beta, gamma, 0};
            DISPATCH_KERNEL;
        }
        case NonlineMode::H_SWISH: {
            using EpilogueOp = cutlass::epilogue::thread::
                    BiasAddLinearCombinationHSwishClamp<
                            ElementOutput, 4, ElementAccumulator, ElementBias,
                            ElementCompute>;
            typename EpilogueOp::Params epilogue{alpha, beta, gamma, scale};
            DISPATCH_KERNEL;
        }
        default:
            megdnn_assert(false,
                          "unsupported nonlinear mode for conv bias operator");
    }
#undef DISPATCH_KERNEL_WITH_TILE_SHAPE
#undef DISPATCH_KERNEL
}
#endif
#define INST(need_load_from_const_mem) \
    template void megdnn::cuda::cutlass_wrapper:: \
            do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4< \
                    need_load_from_const_mem>( \
                    const int8_t* d_src, const int8_t* d_filter, \
                    const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, \
                    int* workspace, const convolution::ConvParam& param, \
                    uint32_t nonlinear_mode, float alpha, float beta, \
                    float gamma, float scale, \
                    const GemmCoord& threadblock_shape, \
                    const GemmCoord& warp_shape, cudaStream_t stream);
INST(true);
INST(false);
#undef INST

/* ====== cutlass kernel wrapper for int8 nchw4 layout ====== */
#if MEGDNN_TEGRA_X1
template <bool NeedLoadFromConstMem>
void megdnn::cuda::cutlass_wrapper::
        do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4(
                const int8_t* /* d_src */, const int8_t* /* d_filter */,
                const int32_t* /* d_bias */, const int8_t* /* d_z */,
                int8_t* /* d_dst */, int* /* workspace */,
                const convolution::ConvParam& /* param */,
                uint32_t /* nonlinear_mode */, float /* alpha */,
                float /* beta */, float /* gamma */, float /* scale */,
                const GemmCoord& /* threadblock_shape */,
                const GemmCoord& /* warp_shape */, int /* stages */,
                cudaStream_t /* stream */) {}
#else
template <bool NeedLoadFromConstMem>
void megdnn::cuda::cutlass_wrapper::
        do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4(
                const int8_t* d_src, const int8_t* d_filter,
                const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst,
                int* workspace, const convolution::ConvParam& param,
                uint32_t nonlinear_mode, float alpha, float beta, float gamma,
                float scale, const GemmCoord& threadblock_shape,
                const GemmCoord& warp_shape, int stages, cudaStream_t stream) {
#define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \
                                        threadblock_k_, warp_m_, warp_n_, \
                                        warp_k_, stage_, aligned_) \
    if (threadblock_shape.m() == threadblock_m_ && \
        threadblock_shape.n() == threadblock_n_ && \
        threadblock_shape.k() == threadblock_k_ && \
        warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \
        warp_shape.k() == warp_k_ && stages == stage_) { \
        using ThreadBlockShape = \
                cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \
                                         threadblock_k_>; \
        using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \
        using InstructionShape = cutlass::gemm::GemmShape<1, 1, 4>; \
        using Convolution = cutlass::conv::device::Convolution< \
                int8_t, cutlass::layout::TensorNCxHWx<4>, int8_t, \
                cutlass::layout::TensorCxRSKx<4>, ElementOutput, \
                cutlass::layout::TensorNCxHWx<4>, int32_t, \
                cutlass::layout::TensorNCxHWx<4>, int32_t, \
                cutlass::conv::ConvType::kConvolution, \
                cutlass::arch::OpClassSimt, cutlass::arch::Sm61, \
                ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \
                cutlass::conv::threadblock:: \
                        ConvolutionFpropNCxHWxThreadblockSwizzle, \
                stage_, 4, aligned_, NeedLoadFromConstMem>; \
        typename Convolution::ConvolutionParameter conv_param( \
                param.n, param.hi, param.wi, param.ci, param.co, param.fh, \
                param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \
                param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \
        return cutlass_convolution_wrapper<Convolution>( \
                d_src, d_filter, d_bias, d_z, d_dst, workspace, conv_param, \
                epilogue, stream); \
    }
#define DISPATCH_KERNEL \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 32, 64, 32, 32, 2, 16); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 32, 64, 32, 32, 2, 16); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 64, 32, 32, 2, 16); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 32, 64, 32, 32, 2, 16); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 2, 16); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 64, 32, 64, 32, 32, 2, 16); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 32, 32, 64, 32, 2, 16); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 32, 32, 64, 32, 32, 2, 16); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 32, 32, 32, 32, 32, 2, 16); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 128, 16, 16, 128, 16, 1, 8); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 64, 8, 16, 64, 8, 2, 4); \
    megdnn_assert(false, \
                  "unsupported threadblock shape (%dx%dx%d) and warp shape " \
                  "(%dx%dx%d)", \
                  threadblock_shape.m(), threadblock_shape.n(), \
                  threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \
                  warp_shape.k());
    using ElementOutput = int8_t;
    using ElementAccumulator = int32_t;
    using ElementBias = int32_t;
    using ElementCompute = float;
    using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode;
    switch (nonlinear_mode) {
        case NonlineMode::IDENTITY: {
            using EpilogueOp =
                    cutlass::epilogue::thread::BiasAddLinearCombinationClamp<
                            ElementOutput, 4, ElementAccumulator, ElementBias,
                            ElementCompute>;
            typename EpilogueOp::Params epilogue{alpha, beta, gamma};
            DISPATCH_KERNEL;
        }
        case NonlineMode::RELU: {
            using EpilogueOp = cutlass::epilogue::thread::
                    BiasAddLinearCombinationReluClamp<
                            ElementOutput, 4, ElementAccumulator, ElementBias,
                            ElementCompute>;
            typename EpilogueOp::Params epilogue{alpha, beta, gamma, 0};
            DISPATCH_KERNEL;
        }
        case NonlineMode::H_SWISH: {
            using EpilogueOp = cutlass::epilogue::thread::
                    BiasAddLinearCombinationHSwishClamp<
                            ElementOutput, 4, ElementAccumulator, ElementBias,
                            ElementCompute>;
            typename EpilogueOp::Params epilogue{alpha, beta, gamma, scale};
            DISPATCH_KERNEL;
        }
        default:
            megdnn_assert(false,
                          "unsupported nonlinear mode for conv bias operator");
    }
#undef DISPATCH_KERNEL_WITH_TILE_SHAPE
#undef DISPATCH_KERNEL
}
#endif
#define INST(need_load_from_const_mem) \
    template void megdnn::cuda::cutlass_wrapper:: \
            do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4< \
                    need_load_from_const_mem>( \
                    const int8_t* d_src, const int8_t* d_filter, \
                    const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, \
                    int* workspace, const convolution::ConvParam& param, \
                    uint32_t nonlinear_mode, float alpha, float beta, \
                    float gamma, float scale, \
                    const GemmCoord& threadblock_shape, \
                    const GemmCoord& warp_shape, int stages, \
                    cudaStream_t stream);
INST(true);
INST(false);
#undef INST

/* ====== cutlass kernel wrapper for int8 nchw4 layout and nchw output ====== */
#if MEGDNN_TEGRA_X1
template <bool NeedLoadFromConstMem>
void megdnn::cuda::cutlass_wrapper::
        do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nchw(
                const int8_t* /* d_src */, const int8_t* /* d_filter */,
                const float* /* d_bias */, const float* /* d_z */,
                float* /* d_dst */, int* /* workspace */,
                const convolution::ConvParam& /* param */,
                uint32_t /* nonlinear_mode */, float /* alpha */,
                float /* beta */, float /* gamma */, float /* scale */,
                const GemmCoord& /* threadblock_shape */,
                const GemmCoord& /* warp_shape */, int /* stages */,
                cudaStream_t /* stream */) {}
#else
template <bool NeedLoadFromConstMem>
void megdnn::cuda::cutlass_wrapper::
        do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nchw(
                const int8_t* d_src, const int8_t* d_filter,
                const float* d_bias, const float* d_z, float* d_dst,
                int* workspace, const convolution::ConvParam& param,
                uint32_t nonlinear_mode, float alpha, float beta, float gamma,
                float scale, const GemmCoord& threadblock_shape,
                const GemmCoord& warp_shape, int stages, cudaStream_t stream) {
#define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \
                                        threadblock_k_, warp_m_, warp_n_, \
                                        warp_k_, stages_, aligned_) \
    if (threadblock_shape.m() == threadblock_m_ && \
        threadblock_shape.n() == threadblock_n_ && \
        threadblock_shape.k() == threadblock_k_ && \
        warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \
        warp_shape.k() == warp_k_ && stages == stages_) { \
        using ThreadBlockShape = \
                cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \
                                         threadblock_k_>; \
        using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \
        using InstructionShape = cutlass::gemm::GemmShape<1, 1, 4>; \
        using Convolution = cutlass::conv::device::Convolution< \
                int8_t, cutlass::layout::TensorNCxHWx<4>, int8_t, \
                cutlass::layout::TensorCxRSKx<4>, ElementOutput, \
                cutlass::layout::TensorNCHW, float, \
                cutlass::layout::TensorNCHW, int32_t, \
                cutlass::conv::ConvType::kConvolution, \
                cutlass::arch::OpClassSimt, cutlass::arch::Sm61, \
                ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \
                cutlass::conv::threadblock:: \
                        ConvolutionFpropNCxHWxThreadblockSwizzle, \
                stages_, 4, aligned_, NeedLoadFromConstMem, \
                cutlass::arch::OpMultiplyAdd>; \
        typename Convolution::ConvolutionParameter conv_param( \
                param.n, param.hi, param.wi, param.ci, param.co, param.fh, \
                param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \
                param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \
        return cutlass_convolution_wrapper<Convolution>( \
                d_src, d_filter, d_bias, d_z, d_dst, workspace, conv_param, \
                epilogue, stream); \
    }
#define DISPATCH_KERNEL \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 32, 64, 32, 32, 2, 16); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 32, 64, 32, 32, 2, 16); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 64, 32, 32, 2, 16); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 32, 64, 32, 32, 2, 16); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 2, 16); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 64, 32, 64, 32, 32, 2, 16); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 32, 32, 64, 32, 2, 16); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 32, 32, 64, 32, 32, 2, 16); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 32, 32, 32, 32, 32, 2, 16); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 128, 16, 16, 128, 16, 1, 8); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 64, 8, 16, 64, 8, 2, 4); \
    megdnn_assert(false, \
                  "unsupported threadblock shape (%dx%dx%d) and warp shape " \
                  "(%dx%dx%d)", \
                  threadblock_shape.m(), threadblock_shape.n(), \
                  threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \
                  warp_shape.k());
    using ElementOutput = float;
    using ElementAccumulator = int32_t;
    using ElementBias = float;
    using ElementCompute = float;
    using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode;
    switch (nonlinear_mode) {
        case NonlineMode::IDENTITY: {
            using EpilogueOp =
                    cutlass::epilogue::thread::BiasAddLinearCombination<
                            ElementOutput, 1, ElementAccumulator, ElementBias,
                            ElementCompute>;
            typename EpilogueOp::Params epilogue{alpha, beta, gamma};
            DISPATCH_KERNEL;
        }
        case NonlineMode::RELU: {
            using EpilogueOp =
                    cutlass::epilogue::thread::BiasAddLinearCombinationRelu<
                            ElementOutput, 1, ElementAccumulator, ElementBias,
                            ElementCompute>;
            typename EpilogueOp::Params epilogue{alpha, beta, gamma, 0};
            DISPATCH_KERNEL;
        }
        case NonlineMode::H_SWISH: {
            using EpilogueOp =
                    cutlass::epilogue::thread::BiasAddLinearCombinationHSwish<
                            ElementOutput, 1, ElementAccumulator, ElementBias,
                            ElementCompute>;
            typename EpilogueOp::Params epilogue{alpha, beta, gamma, scale};
            DISPATCH_KERNEL;
        }
        default:
            megdnn_assert(false,
                          "unsupported nonlinear mode for conv bias operator");
    }
#undef DISPATCH_KERNEL_WITH_TILE_SHAPE
#undef DISPATCH_KERNEL
}
#endif
#define INST(need_load_from_const_mem) \
    template void megdnn::cuda::cutlass_wrapper:: \
            do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nchw< \
                    need_load_from_const_mem>( \
                    const int8_t* d_src, const int8_t* d_filter, \
                    const float* d_bias, const float* d_z, float* d_dst, \
                    int* workspace, const convolution::ConvParam& param, \
                    uint32_t nonlinear_mode, float alpha, float beta, \
                    float gamma, float scale, \
                    const GemmCoord& threadblock_shape, \
                    const GemmCoord& warp_shape, int stages, \
                    cudaStream_t stream);
INST(true);
INST(false);
#undef INST

/* ===== cutlass kernel wrapper for int8 nchw4 layout and nchw32 output ===== */
#if MEGDNN_TEGRA_X1
template <bool NeedLoadFromConstMem>
void megdnn::cuda::cutlass_wrapper::
        do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32(
                const int8_t* /* d_src */, const int8_t* /* d_filter */,
                const int32_t* /* d_bias */, const int8_t* /* d_z */,
                int8_t* /* d_dst */, int* /* workspace */,
                const convolution::ConvParam& /* param */,
                uint32_t /* nonlinear_mode */, float /* alpha */,
                float /* beta */, float /* gamma */, float /* scale */,
                const GemmCoord& /* threadblock_shape */,
                const GemmCoord& /* warp_shape */, int /* stages */,
                cudaStream_t /* stream */) {}
#else
template <bool NeedLoadFromConstMem>
void megdnn::cuda::cutlass_wrapper::
        do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32(
                const int8_t* d_src, const int8_t* d_filter,
                const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst,
                int* workspace, const convolution::ConvParam& param,
                uint32_t nonlinear_mode, float alpha, float beta, float gamma,
                float scale, const GemmCoord& threadblock_shape,
                const GemmCoord& warp_shape, int stages, cudaStream_t stream) {
#define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \
                                        threadblock_k_, warp_m_, warp_n_, \
                                        warp_k_, stages_, aligned_) \
    if (threadblock_shape.m() == threadblock_m_ && \
        threadblock_shape.n() == threadblock_n_ && \
        threadblock_shape.k() == threadblock_k_ && \
        warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \
        warp_shape.k() == warp_k_ && stages == stages_) { \
        using ThreadBlockShape = \
                cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \
                                         threadblock_k_>; \
        using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \
        using InstructionShape = cutlass::gemm::GemmShape<1, 1, 4>; \
        using Convolution = cutlass::conv::device::Convolution< \
                int8_t, cutlass::layout::TensorNCxHWx<4>, int8_t, \
                cutlass::layout::TensorCxRSKx<4>, ElementOutput, \
                cutlass::layout::TensorNCxHWx<32>, int32_t, \
                cutlass::layout::TensorNCxHWx<32>, int32_t, \
                cutlass::conv::ConvType::kConvolution, \
                cutlass::arch::OpClassSimt, cutlass::arch::Sm61, \
                ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \
                cutlass::conv::threadblock:: \
                        ConvolutionFpropNCxHWxThreadblockSwizzle, \
                stages_, 4, aligned_, NeedLoadFromConstMem>; \
        typename Convolution::ConvolutionParameter conv_param( \
                param.n, param.hi, param.wi, param.ci, param.co, param.fh, \
                param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \
                param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \
        return cutlass_convolution_wrapper<Convolution>( \
                d_src, d_filter, d_bias, d_z, d_dst, workspace, conv_param, \
                epilogue, stream); \
    }
#define DISPATCH_KERNEL \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 32, 64, 32, 32, 2, 16); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 32, 64, 32, 32, 2, 16); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 64, 32, 32, 2, 16); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 32, 64, 32, 32, 2, 16); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 2, 16); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 64, 32, 64, 32, 32, 2, 16); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 32, 32, 64, 32, 2, 16); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 32, 32, 64, 32, 32, 2, 16); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 32, 32, 32, 32, 32, 2, 16); \
    megdnn_assert(false, \
                  "unsupported threadblock shape (%dx%dx%d) and warp shape " \
                  "(%dx%dx%d)", \
                  threadblock_shape.m(), threadblock_shape.n(), \
                  threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \
                  warp_shape.k());
    using ElementOutput = int8_t;
    using ElementAccumulator = int32_t;
    using ElementBias = int32_t;
    using ElementCompute = float;
    using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode;
    switch (nonlinear_mode) {
        case NonlineMode::IDENTITY: {
            using EpilogueOp =
                    cutlass::epilogue::thread::BiasAddLinearCombinationClamp<
                            ElementOutput, 4, ElementAccumulator, ElementBias,
                            ElementCompute>;
            typename EpilogueOp::Params epilogue{alpha, beta, gamma};
            DISPATCH_KERNEL;
        }
        case NonlineMode::RELU: {
            using EpilogueOp = cutlass::epilogue::thread::
                    BiasAddLinearCombinationReluClamp<
                            ElementOutput, 4, ElementAccumulator, ElementBias,
                            ElementCompute>;
            typename EpilogueOp::Params epilogue{alpha, beta, gamma, 0};
            DISPATCH_KERNEL;
        }
        case NonlineMode::H_SWISH: {
            using EpilogueOp = cutlass::epilogue::thread::
                    BiasAddLinearCombinationHSwishClamp<
                            ElementOutput, 4, ElementAccumulator, ElementBias,
                            ElementCompute>;
            typename EpilogueOp::Params epilogue{alpha, beta, gamma, scale};
            DISPATCH_KERNEL;
        }
        default:
            megdnn_assert(false,
                          "unsupported nonlinear mode for conv bias operator");
    }
#undef DISPATCH_KERNEL_WITH_TILE_SHAPE
#undef DISPATCH_KERNEL
}
#endif
#define INST(need_load_from_const_mem) \
    template void megdnn::cuda::cutlass_wrapper:: \
            do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32< \
                    need_load_from_const_mem>( \
                    const int8_t* d_src, const int8_t* d_filter, \
                    const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, \
                    int* workspace, const convolution::ConvParam& param, \
                    uint32_t nonlinear_mode, float alpha, float beta, \
                    float gamma, float scale, \
                    const GemmCoord& threadblock_shape, \
                    const GemmCoord& warp_shape, int stages, \
                    cudaStream_t stream);
INST(true);
INST(false);
#undef INST

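// The 4-bit wrappers below keep int8_t/uint8_t pointers in their public
// signatures (two 4-bit values packed per byte) and reinterpret_cast them to
// cutlass::int4b_t / cutlass::uint4b_t before handing them to CUTLASS.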
/* ====== cutlass kernel wrapper for int4 x int4 nchw64 layout ====== */
#if MEGDNN_TEGRA_X1
template <bool NeedLoadFromConstMem>
void megdnn::cuda::cutlass_wrapper::
        do_conv_bias_int4_int4_implicit_gemm_imma_ncdiv64hw64(
                const int8_t* /* d_src */, const int8_t* /* d_filter */,
                const int32_t* /* d_bias */, const int8_t* /* d_z */,
                int8_t* /* d_dst */, int* /* workspace */,
                const convolution::ConvParam& /* param */,
                uint32_t /* nonlinear_mode */, float /* alpha */,
                float /* beta */, float /* gamma */, float /* scale */,
                const GemmCoord& /* threadblock_shape */,
                const GemmCoord& /* warp_shape */, cudaStream_t /* stream */) {}
#else
template <bool NeedLoadFromConstMem>
void megdnn::cuda::cutlass_wrapper::
        do_conv_bias_int4_int4_implicit_gemm_imma_ncdiv64hw64(
                const int8_t* d_src, const int8_t* d_filter,
                const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst,
                int* workspace, const convolution::ConvParam& param,
                uint32_t nonlinear_mode, float alpha, float beta, float gamma,
                float scale, const GemmCoord& threadblock_shape,
                const GemmCoord& warp_shape, cudaStream_t stream) {
#define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \
                                        threadblock_k_, warp_m_, warp_n_, \
                                        warp_k_) \
    if (threadblock_shape.m() == threadblock_m_ && \
        threadblock_shape.n() == threadblock_n_ && \
        threadblock_shape.k() == threadblock_k_ && \
        warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \
        warp_shape.k() == warp_k_) { \
        using ThreadBlockShape = \
                cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \
                                         threadblock_k_>; \
        using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \
        using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; \
        using Convolution = cutlass::conv::device::Convolution< \
                cutlass::int4b_t, cutlass::layout::TensorNCxHWx<64>, \
                cutlass::int4b_t, cutlass::layout::TensorCxRSKx<64>, \
                ElementOutput, cutlass::layout::TensorNCxHWx<64>, int32_t, \
                cutlass::layout::TensorNCxHWx<64>, int32_t, \
                cutlass::conv::ConvType::kConvolution, \
                cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \
                ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \
                cutlass::conv::threadblock:: \
                        ConvolutionFpropNCxHWxThreadblockSwizzle, \
                2, 32, 32, NeedLoadFromConstMem>; \
        typename Convolution::ConvolutionParameter conv_param( \
                param.n, param.hi, param.wi, param.ci, param.co, param.fh, \
                param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \
                param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \
        return cutlass_convolution_wrapper<Convolution>( \
                reinterpret_cast<const cutlass::int4b_t*>(d_src), \
                reinterpret_cast<const cutlass::int4b_t*>(d_filter), d_bias, \
                reinterpret_cast<const cutlass::int4b_t*>(d_z), \
                reinterpret_cast<cutlass::int4b_t*>(d_dst), workspace, \
                conv_param, epilogue, stream); \
    }
#define DISPATCH_KERNEL \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 128, 64, 64, 128); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(256, 128, 128, 64, 64, 128); \
    megdnn_assert(false, \
                  "unsupported threadblock shape (%dx%dx%d) and warp shape " \
                  "(%dx%dx%d)", \
                  threadblock_shape.m(), threadblock_shape.n(), \
                  threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \
                  warp_shape.k());
    using ElementOutput = cutlass::int4b_t;
    using ElementAccumulator = int32_t;
    using ElementBias = int32_t;
    using ElementCompute = float;
    using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode;
    switch (nonlinear_mode) {
        case NonlineMode::IDENTITY: {
            using EpilogueOp =
                    cutlass::epilogue::thread::BiasAddLinearCombinationClamp<
                            ElementOutput, 16, ElementAccumulator, ElementBias,
                            ElementCompute>;
            typename EpilogueOp::Params epilogue{alpha, beta, gamma};
            DISPATCH_KERNEL;
        }
        case NonlineMode::RELU: {
            using EpilogueOp = cutlass::epilogue::thread::
                    BiasAddLinearCombinationReluClamp<
                            ElementOutput, 16, ElementAccumulator, ElementBias,
                            ElementCompute>;
            typename EpilogueOp::Params epilogue{alpha, beta, gamma, 0};
            DISPATCH_KERNEL;
        }
        case NonlineMode::H_SWISH: {
            using EpilogueOp = cutlass::epilogue::thread::
                    BiasAddLinearCombinationHSwishClamp<
                            ElementOutput, 16, ElementAccumulator, ElementBias,
                            ElementCompute>;
            typename EpilogueOp::Params epilogue{alpha, beta, gamma, scale};
            DISPATCH_KERNEL;
        }
        default:
            megdnn_assert(false,
                          "unsupported nonlinear mode for conv bias operator");
    }
#undef DISPATCH_KERNEL_WITH_TILE_SHAPE
#undef DISPATCH_KERNEL
}
#endif
#define INST(need_load_from_const_mem) \
    template void megdnn::cuda::cutlass_wrapper:: \
            do_conv_bias_int4_int4_implicit_gemm_imma_ncdiv64hw64< \
                    need_load_from_const_mem>( \
                    const int8_t* d_src, const int8_t* d_filter, \
                    const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, \
                    int* workspace, const convolution::ConvParam& param, \
                    uint32_t nonlinear_mode, float alpha, float beta, \
                    float gamma, float scale, \
                    const GemmCoord& threadblock_shape, \
                    const GemmCoord& warp_shape, cudaStream_t stream);
INST(true);
#undef INST

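// The asymmetric uint4 x int4 variant additionally takes delta/theta terms and
// the source zero point: the zero point is forwarded to the convolution as an
// extra argument, while delta/theta are folded into the epilogue parameters.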
/* ====== cutlass kernel wrapper for uint4 x int4 nchw64 layout ====== */
#if MEGDNN_TEGRA_X1
template <bool NeedLoadFromConstMem>
void megdnn::cuda::cutlass_wrapper::
        do_conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64(
                const uint8_t* /* d_src */, const int8_t* /* d_filter */,
                const int32_t* /* d_bias */, const uint8_t* /* d_z */,
                uint8_t* /* d_dst */, int* /* workspace */,
                const convolution::ConvParam& /* param */,
                uint32_t /* nonlinear_mode */, float /* alpha */,
                float /* beta */, float /* gamma */, float /* delta */,
                float /* theta */, float /* scale */,
                uint8_t /* src_zero_point */,
                const GemmCoord& /* threadblock_shape */,
                const GemmCoord& /* warp_shape */, cudaStream_t /* stream */) {}
#else
template <bool NeedLoadFromConstMem>
void megdnn::cuda::cutlass_wrapper::
        do_conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64(
                const uint8_t* d_src, const int8_t* d_filter,
                const int32_t* d_bias, const uint8_t* d_z, uint8_t* d_dst,
                int* workspace, const convolution::ConvParam& param,
                uint32_t nonlinear_mode, float alpha, float beta, float gamma,
                float delta, float theta, float scale, uint8_t src_zero_point,
                const GemmCoord& threadblock_shape, const GemmCoord& warp_shape,
                cudaStream_t stream) {
#define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \
                                        threadblock_k_, warp_m_, warp_n_, \
                                        warp_k_) \
    if (threadblock_shape.m() == threadblock_m_ && \
        threadblock_shape.n() == threadblock_n_ && \
        threadblock_shape.k() == threadblock_k_ && \
        warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \
        warp_shape.k() == warp_k_) { \
        using ThreadBlockShape = \
                cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \
                                         threadblock_k_>; \
        using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \
        using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; \
        using Convolution = cutlass::conv::device::Convolution< \
                cutlass::uint4b_t, cutlass::layout::TensorNCxHWx<64>, \
                cutlass::int4b_t, cutlass::layout::TensorCxRSKx<64>, \
                ElementOutput, cutlass::layout::TensorNCxHWx<64>, int32_t, \
                cutlass::layout::TensorNCxHWx<64>, int32_t, \
                cutlass::conv::ConvType::kConvolution, \
                cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \
                ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \
                cutlass::conv::threadblock:: \
                        ConvolutionFpropNCxHWxThreadblockSwizzle, \
                2, 32, 32, NeedLoadFromConstMem>; \
        typename Convolution::ConvolutionParameter conv_param( \
                param.n, param.hi, param.wi, param.ci, param.co, param.fh, \
                param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \
                param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \
        return cutlass_convolution_wrapper<Convolution>( \
                reinterpret_cast<const cutlass::uint4b_t*>(d_src), \
                reinterpret_cast<const cutlass::int4b_t*>(d_filter), d_bias, \
                reinterpret_cast<const cutlass::uint4b_t*>(d_z), \
                reinterpret_cast<cutlass::uint4b_t*>(d_dst), workspace, \
                conv_param, epilogue, stream, {src_zero_point}); \
    }
#define DISPATCH_KERNEL \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 128, 64, 64, 128); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(256, 128, 128, 64, 64, 128); \
    megdnn_assert(false, \
                  "unsupported threadblock shape (%dx%dx%d) and warp shape " \
                  "(%dx%dx%d)", \
                  threadblock_shape.m(), threadblock_shape.n(), \
                  threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \
                  warp_shape.k());
    using ElementOutput = cutlass::uint4b_t;
    using ElementAccumulator = int32_t;
    using ElementBias = int32_t;
    using ElementCompute = float;
    using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode;
    switch (nonlinear_mode) {
        case NonlineMode::IDENTITY: {
            using EpilogueOp =
                    cutlass::epilogue::thread::BiasAddLinearCombinationClamp<
                            ElementOutput, 16, ElementAccumulator, ElementBias,
                            ElementCompute>;
            typename EpilogueOp::Params epilogue{alpha, beta, gamma,
                                                 delta + theta};
            DISPATCH_KERNEL;
        }
        case NonlineMode::RELU: {
            using EpilogueOp = cutlass::epilogue::thread::
                    BiasAddLinearCombinationReluClamp<
                            ElementOutput, 16, ElementAccumulator, ElementBias,
                            ElementCompute>;
            typename EpilogueOp::Params epilogue{alpha, beta, gamma,
                                                 0, delta, theta};
            DISPATCH_KERNEL;
        }
        case NonlineMode::H_SWISH: {
            using EpilogueOp = cutlass::epilogue::thread::
                    BiasAddLinearCombinationHSwishClamp<
                            ElementOutput, 16, ElementAccumulator, ElementBias,
                            ElementCompute>;
            typename EpilogueOp::Params epilogue{alpha, beta, gamma,
                                                 scale, delta, theta};
            DISPATCH_KERNEL;
        }
        default:
            megdnn_assert(false,
                          "unsupported nonlinear mode for conv bias operator");
    }
#undef DISPATCH_KERNEL_WITH_TILE_SHAPE
#undef DISPATCH_KERNEL
}
#endif
#define INST(need_load_from_const_mem) \
    template void megdnn::cuda::cutlass_wrapper:: \
            do_conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64< \
                    need_load_from_const_mem>( \
                    const uint8_t* d_src, const int8_t* d_filter, \
                    const int32_t* d_bias, const uint8_t* d_z, uint8_t* d_dst, \
                    int* workspace, const convolution::ConvParam& param, \
                    uint32_t nonlinear_mode, float alpha, float beta, \
                    float gamma, float delta, float theta, float scale, \
                    uint8_t src_zero_point, \
                    const GemmCoord& threadblock_shape, \
                    const GemmCoord& warp_shape, cudaStream_t stream);
INST(true);
#undef INST

/* ===== cutlass kernel wrapper for nchw4 layout and nhwc output ===== */
#if MEGDNN_TEGRA_X1
template <bool signedness>
void megdnn::cuda::cutlass_wrapper::
        do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nhwc(
                const int8_t* /* d_src */, const int8_t* /* d_filter */,
                const int32_t* /* d_bias */, const int8_t* /* d_z */,
                int8_t* /* d_dst */, int* /* workspace */,
                const convolution::ConvParam& /* param */,
                uint32_t /* nonlinear_mode */, float /* alpha */,
                float /* beta */, float /* gamma */, float /* delta */,
                float /* theta */, float /* scale */,
                const GemmCoord& /* threadblock_shape */,
                const GemmCoord& /* warp_shape */, int /* stages */,
                cudaStream_t /* stream */) {}
#else
template <bool signedness>
void megdnn::cuda::cutlass_wrapper::
        do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nhwc(
                const int8_t* d_src, const int8_t* d_filter,
                const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst,
                int* workspace, const convolution::ConvParam& param,
                uint32_t nonlinear_mode, float alpha, float beta, float gamma,
                float delta, float theta, float scale,
                const GemmCoord& threadblock_shape, const GemmCoord& warp_shape,
                int stages, cudaStream_t stream) {
#define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \
                                        threadblock_k_, warp_m_, warp_n_, \
                                        warp_k_, stages_, aligned_) \
    if (threadblock_shape.m() == threadblock_m_ && \
        threadblock_shape.n() == threadblock_n_ && \
        threadblock_shape.k() == threadblock_k_ && \
        warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \
        warp_shape.k() == warp_k_ && stages == stages_) { \
        using ThreadBlockShape = \
                cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \
                                         threadblock_k_>; \
        using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \
        using InstructionShape = cutlass::gemm::GemmShape<1, 1, 4>; \
        using Convolution = cutlass::conv::device::Convolution< \
                int8_t, cutlass::layout::TensorNCxHWx<4>, int8_t, \
                cutlass::layout::TensorCxRSKx<4>, ElementOutput, \
                cutlass::layout::TensorNHWC, int32_t, \
                cutlass::layout::TensorNHWC, int32_t, \
                cutlass::conv::ConvType::kConvolution, \
                cutlass::arch::OpClassSimt, cutlass::arch::Sm75, \
                ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \
                cutlass::conv::threadblock:: \
                        ConvolutionFpropNCxHWxThreadblockSwizzle, \
                stages_, 4, aligned_, true, \
                cutlass::arch::OpMultiplyAddSaturate>; \
        typename Convolution::ConvolutionParameter conv_param( \
                param.n, param.hi, param.wi, param.ci, param.co, param.fh, \
                param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \
                param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \
        return cutlass_convolution_wrapper<Convolution>( \
                d_src, d_filter, d_bias, \
                reinterpret_cast<const ElementOutput*>(d_z), \
                reinterpret_cast<ElementOutput*>(d_dst), workspace, \
                conv_param, epilogue, stream); \
    }
#define DISPATCH_KERNEL \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 32, 64, 32, 32, 2, 16); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 32, 64, 32, 32, 2, 16); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 64, 32, 32, 2, 16); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 32, 64, 32, 32, 2, 16); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 2, 16); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 64, 32, 64, 32, 32, 2, 16); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 32, 32, 64, 32, 2, 16); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 32, 32, 64, 32, 32, 2, 16); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 32, 32, 32, 32, 32, 2, 16); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 128, 16, 16, 128, 16, 1, 8); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 64, 8, 16, 64, 8, 2, 4); \
    megdnn_assert(false, \
                  "unsupported threadblock shape (%dx%dx%d) and warp shape " \
                  "(%dx%dx%d)", \
                  threadblock_shape.m(), threadblock_shape.n(), \
                  threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \
                  warp_shape.k());
    using ElementOutput = cutlass::integer_subbyte<4, signedness>;
    using ElementAccumulator = int32_t;
    using ElementBias = int32_t;
    using ElementCompute = float;
    using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode;
    switch (nonlinear_mode) {
        case NonlineMode::IDENTITY: {
            using EpilogueOp =
                    cutlass::epilogue::thread::BiasAddLinearCombinationClamp<
                            ElementOutput, 8, ElementAccumulator, ElementBias,
                            ElementCompute>;
            typename EpilogueOp::Params epilogue{alpha, beta, gamma,
                                                 delta + theta};
            DISPATCH_KERNEL;
        }
        case NonlineMode::RELU: {
            using EpilogueOp = cutlass::epilogue::thread::
                    BiasAddLinearCombinationReluClamp<
                            ElementOutput, 8, ElementAccumulator, ElementBias,
                            ElementCompute>;
            typename EpilogueOp::Params epilogue{alpha, beta, gamma,
                                                 0, delta, theta};
            DISPATCH_KERNEL;
        }
        case NonlineMode::H_SWISH: {
            using EpilogueOp = cutlass::epilogue::thread::
                    BiasAddLinearCombinationHSwishClamp<
                            ElementOutput, 8, ElementAccumulator, ElementBias,
                            ElementCompute>;
            typename EpilogueOp::Params epilogue{alpha, beta, gamma,
                                                 scale, delta, theta};
            DISPATCH_KERNEL;
        }
        default:
            megdnn_assert(false,
                          "unsupported nonlinear mode for conv bias operator");
    }
#undef DISPATCH_KERNEL_WITH_TILE_SHAPE
#undef DISPATCH_KERNEL
}
#endif
#define INST(signedness) \
    template void megdnn::cuda::cutlass_wrapper:: \
            do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nhwc<signedness>( \
                    const int8_t* d_src, const int8_t* d_filter, \
                    const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, \
                    int* workspace, const convolution::ConvParam& param, \
                    uint32_t nonlinear_mode, float alpha, float beta, \
                    float gamma, float delta, float theta, float scale, \
                    const GemmCoord& threadblock_shape, \
                    const GemmCoord& warp_shape, int stages, \
                    cudaStream_t stream);
INST(true);
INST(false);
#undef INST

// vim: syntax=cuda.doxygen
