- /**
- * \file dnn/src/cuda/conv_bias/cutlass_convolution_wrapper.cu
- * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
- *
- * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
- * implied.
- */
- // ignore warning of cutlass
- #pragma GCC diagnostic push
- #pragma GCC diagnostic ignored "-Wunused-parameter"
- #pragma GCC diagnostic ignored "-Wstrict-aliasing"
-
- #if !MEGDNN_TEGRA_X1
- #include "cutlass/convolution/device/convolution.h"
- #endif
- #include "src/common/opr_param_defs_enumv.cuh"
- #include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh"
- #pragma GCC diagnostic pop
-
- using namespace megdnn;
- using namespace cuda;
- using namespace cutlass_wrapper;
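-
- // Each wrapper below performs runtime dispatch into compile-time CUTLASS
- // template instantiations: it compares the requested threadblock/warp tile
- // shapes against a fixed list and, on a match, launches the corresponding
- // cutlass::convolution::device::Convolution kernel through
- // cutlass_convolution_wrapper. On Tegra X1 builds every wrapper compiles to
- // an empty stub.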
-
- /* =============== cutlass kernel wrapper for nchw32 layout =============== */
- #if MEGDNN_TEGRA_X1
- template <bool NeedLoadFromConstMem>
- void megdnn::cuda::cutlass_wrapper::
- do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32(
- const int8_t* /* d_src */, const int8_t* /* d_filter */,
- const int32_t* /* d_bias */, const int8_t* /* d_z */,
- int8_t* /* d_dst */, int* /* workspace */,
- const convolution::ConvParam& /* param */,
- uint32_t /* nonlinear_mode */, float /* alpha */,
- float /* beta */, float /* gamma */, float /* scale */,
- const GemmCoord& /* threadblock_shape */,
- const GemmCoord& /* warp_shape */, cudaStream_t /* stream */) {}
- #else
- template <bool NeedLoadFromConstMem>
- void megdnn::cuda::cutlass_wrapper::
- do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32(
- const int8_t* d_src, const int8_t* d_filter,
- const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst,
- int* workspace, const convolution::ConvParam& param,
- uint32_t nonlinear_mode, float alpha, float beta, float gamma,
- float scale, const GemmCoord& threadblock_shape,
- const GemmCoord& warp_shape, cudaStream_t stream) {
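- // DISPATCH_KERNEL_WITH_TILE_SHAPE expands to one guarded instantiation per
- // tile shape: when the runtime threadblock/warp shapes match the macro
- // arguments, it builds a Convolution type for the 32-wide interleaved
- // layout (Tensor Ops, 8x8x16 mma instructions, Sm75) and returns after
- // launching it, so at most one branch ever fires.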
- #define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \
- threadblock_k_, warp_m_, warp_n_, \
- warp_k_) \
- if (threadblock_shape.m() == threadblock_m_ && \
- threadblock_shape.n() == threadblock_n_ && \
- threadblock_shape.k() == threadblock_k_ && \
- warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \
- warp_shape.k() == warp_k_) { \
- using ThreadBlockShape = \
- cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \
- threadblock_k_>; \
- using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \
- using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; \
- using Convolution = cutlass::convolution::device::Convolution< \
- int8_t, cutlass::layout::TensorNCxHWx<32>, int8_t, \
- cutlass::layout::TensorCxRSKx<32>, ElementOutput, \
- cutlass::layout::TensorNCxHWx<32>, int32_t, \
- cutlass::layout::TensorNCxHWx<32>, int32_t, \
- cutlass::convolution::ConvType::kConvolution, \
- cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \
- ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \
- cutlass::convolution::threadblock:: \
- ConvolutionNCxHWxThreadblockSwizzle< \
- cutlass::convolution::ConvType::kConvolution>, \
- 2, 16, 16, NeedLoadFromConstMem>; \
- typename Convolution::ConvolutionParameter conv_param{ \
- param.n, param.ci, param.co, param.hi, param.wi, \
- param.fh, param.fw, param.ho, param.wo, param.sh, \
- param.sw, param.ph, param.pw, 1, 1}; \
- return cutlass_convolution_wrapper<Convolution>( \
- d_src, d_filter, d_bias, d_z, d_dst, workspace, conv_param, \
- epilogue, stream); \
- }
- #define DISPATCH_KERNEL \
- DISPATCH_KERNEL_WITH_TILE_SHAPE(256, 128, 64, 64, 64, 64); \
- DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 256, 64, 64, 64, 64); \
- DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 64, 64, 64, 64); \
- DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 64, 32, 64, 64); \
- DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 32, 64); \
- DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 64, 64, 32, 32, 64); \
- DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 64, 32, 16, 64); \
- megdnn_assert(false, \
- "unsupported threadblock shape (%dx%dx%d) and warp shape " \
- "(%dx%dx%d)", \
- threadblock_shape.m(), threadblock_shape.n(), \
- threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \
- warp_shape.k());
- using ElementOutput = int8_t;
- using ElementAccumulator = int32_t;
- using ElementBias = int32_t;
- using ElementCompute = float;
- using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode;
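- // Select the epilogue from the activation mode; each epilogue computes,
- // roughly, activation(alpha * conv + beta * bias + gamma * z), clamped to
- // the int8 output range. Every branch ends in DISPATCH_KERNEL, which
- // either returns through a matching instantiation or asserts.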
- switch (nonlinear_mode) {
- case NonlineMode::IDENTITY: {
- using EpilogueOp =
- cutlass::epilogue::thread::BiasAddLinearCombinationClamp<
- ElementOutput, 8, ElementAccumulator, ElementBias,
- ElementCompute>;
- typename EpilogueOp::Params epilogue{alpha, beta, gamma};
- DISPATCH_KERNEL;
- }
- case NonlineMode::RELU: {
- using EpilogueOp = cutlass::epilogue::thread::
- BiasAddLinearCombinationReluClamp<
- ElementOutput, 8, ElementAccumulator, ElementBias,
- ElementCompute>;
- typename EpilogueOp::Params epilogue{alpha, beta, gamma, 0};
- DISPATCH_KERNEL;
- }
- case NonlineMode::H_SWISH: {
- using EpilogueOp = cutlass::epilogue::thread::
- BiasAddLinearCombinationHSwishClamp<
- ElementOutput, 8, ElementAccumulator, ElementBias,
- ElementCompute>;
- typename EpilogueOp::Params epilogue{alpha, beta, gamma, scale};
- DISPATCH_KERNEL;
- }
- default:
- megdnn_assert(false,
- "unsupported nonlinear mode for conv bias operator");
- }
- #undef DISPATCH_KERNEL_WITH_TILE_SHAPE
- #undef DISPATCH_KERNEL
- }
- #endif
-
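- // Explicitly instantiate both filter-loading strategies so callers in
- // other translation units can link against them.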
- #define INST(need_load_from_const_mem) \
- template void megdnn::cuda::cutlass_wrapper:: \
- do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32< \
- need_load_from_const_mem>( \
- const int8_t* d_src, const int8_t* d_filter, \
- const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, \
- int* workspace, const convolution::ConvParam& param, \
- uint32_t nonlinear_mode, float alpha, float beta, \
- float gamma, float scale, \
- const GemmCoord& threadblock_shape, \
- const GemmCoord& warp_shape, cudaStream_t stream);
- INST(true);
- INST(false);
- #undef INST
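-
- // A hypothetical call into the wrapper above (an illustrative sketch only;
- // the pointer and scalar names here are assumptions, not code from this
- // module):
- //
- //     using namespace megdnn::cuda::cutlass_wrapper;
- //     GemmCoord threadblock{128, 128, 64}, warp{64, 64, 64};
- //     do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32<true>(
- //             d_src, d_filter, d_bias, d_z, d_dst, workspace, param,
- //             static_cast<uint32_t>(
- //                     megdnn::param_enumv::ConvBias::NonlineMode::RELU),
- //             alpha, beta, gamma, scale, threadblock, warp, stream);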
-
- /* ==== cutlass kernel wrapper for nchw32 layout and nchw4 output ===== */
- #if MEGDNN_TEGRA_X1
- template <bool NeedLoadFromConstMem>
- void megdnn::cuda::cutlass_wrapper::
- do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4(
- const int8_t* /* d_src */, const int8_t* /* d_filter */,
- const int32_t* /* d_bias */, const int8_t* /* d_z */,
- int8_t* /* d_dst */, int* /* workspace */,
- const convolution::ConvParam& /* param */,
- uint32_t /* nonlinear_mode */, float /* alpha */,
- float /* beta */, float /* gamma */, float /* scale */,
- const GemmCoord& /* threadblock_shape */,
- const GemmCoord& /* warp_shape */, cudaStream_t /* stream */) {}
- #else
- template <bool NeedLoadFromConstMem>
- void megdnn::cuda::cutlass_wrapper::
- do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4(
- const int8_t* d_src, const int8_t* d_filter,
- const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst,
- int* workspace, const convolution::ConvParam& param,
- uint32_t nonlinear_mode, float alpha, float beta, float gamma,
- float scale, const GemmCoord& threadblock_shape,
- const GemmCoord& warp_shape, cudaStream_t stream) {
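- // Same dispatch scheme as the wrapper above; only the destination layout
- // changes to NCHW4 (TensorNCxHWx<4>), so the epilogues below use vector
- // width 4 instead of 8.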
- #define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \
- threadblock_k_, warp_m_, warp_n_, \
- warp_k_) \
- if (threadblock_shape.m() == threadblock_m_ && \
- threadblock_shape.n() == threadblock_n_ && \
- threadblock_shape.k() == threadblock_k_ && \
- warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \
- warp_shape.k() == warp_k_) { \
- using ThreadBlockShape = \
- cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \
- threadblock_k_>; \
- using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \
- using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; \
- using Convolution = cutlass::convolution::device::Convolution< \
- int8_t, cutlass::layout::TensorNCxHWx<32>, int8_t, \
- cutlass::layout::TensorCxRSKx<32>, ElementOutput, \
- cutlass::layout::TensorNCxHWx<4>, int32_t, \
- cutlass::layout::TensorNCxHWx<4>, int32_t, \
- cutlass::convolution::ConvType::kConvolution, \
- cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \
- ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \
- cutlass::convolution::threadblock:: \
- ConvolutionNCxHWxThreadblockSwizzle< \
- cutlass::convolution::ConvType::kConvolution>, \
- 2, 16, 16, NeedLoadFromConstMem>; \
- typename Convolution::ConvolutionParameter conv_param{ \
- param.n, param.ci, param.co, param.hi, param.wi, \
- param.fh, param.fw, param.ho, param.wo, param.sh, \
- param.sw, param.ph, param.pw, 1, 1}; \
- return cutlass_convolution_wrapper<Convolution>( \
- d_src, d_filter, d_bias, d_z, d_dst, workspace, conv_param, \
- epilogue, stream); \
- }
- #define DISPATCH_KERNEL \
- DISPATCH_KERNEL_WITH_TILE_SHAPE(256, 128, 64, 64, 64, 64); \
- DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 256, 64, 64, 64, 64); \
- DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 64, 64, 64, 64); \
- DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 64, 32, 64, 64); \
- DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 32, 64); \
- DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 64, 64, 32, 32, 64); \
- DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 64, 16, 32, 64); \
- megdnn_assert(false, \
- "unsupported threadblock shape (%dx%dx%d) and warp shape " \
- "(%dx%dx%d)", \
- threadblock_shape.m(), threadblock_shape.n(), \
- threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \
- warp_shape.k());
- using ElementOutput = int8_t;
- using ElementAccumulator = int32_t;
- using ElementBias = int32_t;
- using ElementCompute = float;
- using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode;
- switch (nonlinear_mode) {
- case NonlineMode::IDENTITY: {
- using EpilogueOp =
- cutlass::epilogue::thread::BiasAddLinearCombinationClamp<
- ElementOutput, 4, ElementAccumulator, ElementBias,
- ElementCompute>;
- typename EpilogueOp::Params epilogue{alpha, beta, gamma};
- DISPATCH_KERNEL;
- }
- case NonlineMode::RELU: {
- using EpilogueOp = cutlass::epilogue::thread::
- BiasAddLinearCombinationReluClamp<
- ElementOutput, 4, ElementAccumulator, ElementBias,
- ElementCompute>;
- typename EpilogueOp::Params epilogue{alpha, beta, gamma, 0};
- DISPATCH_KERNEL;
- }
- case NonlineMode::H_SWISH: {
- using EpilogueOp = cutlass::epilogue::thread::
- BiasAddLinearCombinationHSwishClamp<
- ElementOutput, 4, ElementAccumulator, ElementBias,
- ElementCompute>;
- typename EpilogueOp::Params epilogue{alpha, beta, gamma, scale};
- DISPATCH_KERNEL;
- }
- default:
- megdnn_assert(false,
- "unsupported nonlinear mode for conv bias operator");
- }
- #undef DISPATCH_KERNEL_WITH_TILE_SHAPE
- #undef DISPATCH_KERNEL
- }
- #endif
-
- #define INST(need_load_from_const_mem) \
- template void megdnn::cuda::cutlass_wrapper:: \
- do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4< \
- need_load_from_const_mem>( \
- const int8_t* d_src, const int8_t* d_filter, \
- const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, \
- int* workspace, const convolution::ConvParam& param, \
- uint32_t nonlinear_mode, float alpha, float beta, \
- float gamma, float scale, \
- const GemmCoord& threadblock_shape, \
- const GemmCoord& warp_shape, cudaStream_t stream);
- INST(true);
- INST(false);
- #undef INST
-
- /* ================ cutlass kernel wrapper for nchw4 layout ================= */
- #if MEGDNN_TEGRA_X1
- template <bool NeedLoadFromConstMem>
- void megdnn::cuda::cutlass_wrapper::
- do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4(
- const int8_t* /* d_src */, const int8_t* /* d_filter */,
- const int32_t* /* d_bias */, const int8_t* /* d_z */,
- int8_t* /* d_dst */, int* /* workspace */,
- const convolution::ConvParam& /* param */,
- uint32_t /* nonlinear_mode */, float /* alpha */,
- float /* beta */, float /* gamma */, float /* scale */,
- const GemmCoord& /* threadblock_shape */,
- const GemmCoord& /* warp_shape */, cudaStream_t /* stream */) {}
- #else
- template <bool NeedLoadFromConstMem>
- void megdnn::cuda::cutlass_wrapper::
- do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4(
- const int8_t* d_src, const int8_t* d_filter,
- const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst,
- int* workspace, const convolution::ConvParam& param,
- uint32_t nonlinear_mode, float alpha, float beta, float gamma,
- float scale, const GemmCoord& threadblock_shape,
- const GemmCoord& warp_shape, cudaStream_t stream) {
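- // dp4a variant: SIMT int8 dot products (1x1x4 mma shape) targeting Sm61.
- // The dispatch macro takes two extra arguments, stage_ and aligned_, so
- // the smaller tiles at the end of the shape list can run with a shallower
- // pipeline and narrower access widths (8 or 4 instead of 16).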
- #define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \
- threadblock_k_, warp_m_, warp_n_, \
- warp_k_, stage_, aligned_) \
- if (threadblock_shape.m() == threadblock_m_ && \
- threadblock_shape.n() == threadblock_n_ && \
- threadblock_shape.k() == threadblock_k_ && \
- warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \
- warp_shape.k() == warp_k_) { \
- using ThreadBlockShape = \
- cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \
- threadblock_k_>; \
- using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \
- using InstructionShape = cutlass::gemm::GemmShape<1, 1, 4>; \
- using Convolution = cutlass::convolution::device::Convolution< \
- int8_t, cutlass::layout::TensorNCxHWx<4>, int8_t, \
- cutlass::layout::TensorCxRSKx<4>, ElementOutput, \
- cutlass::layout::TensorNCxHWx<4>, int32_t, \
- cutlass::layout::TensorNCxHWx<4>, int32_t, \
- cutlass::convolution::ConvType::kConvolution, \
- cutlass::arch::OpClassSimt, cutlass::arch::Sm61, \
- ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \
- cutlass::convolution::threadblock:: \
- ConvolutionNCxHWxThreadblockSwizzle< \
- cutlass::convolution::ConvType::kConvolution>, \
- stage_, 4, aligned_, NeedLoadFromConstMem>; \
- typename Convolution::ConvolutionParameter conv_param{ \
- param.n, param.ci, param.co, param.hi, param.wi, \
- param.fh, param.fw, param.ho, param.wo, param.sh, \
- param.sw, param.ph, param.pw, 1, 1}; \
- return cutlass_convolution_wrapper<Convolution>( \
- d_src, d_filter, d_bias, d_z, d_dst, workspace, conv_param, \
- epilogue, stream); \
- }
- #define DISPATCH_KERNEL \
- DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 32, 64, 32, 32, 2, 16); \
- DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 32, 64, 32, 32, 2, 16); \
- DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 64, 32, 32, 2, 16); \
- DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 32, 64, 32, 32, 2, 16); \
- DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 2, 16); \
- DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 64, 32, 64, 32, 32, 2, 16); \
- DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 32, 32, 64, 32, 2, 16); \
- DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 32, 32, 64, 32, 32, 2, 16); \
- DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 32, 32, 32, 32, 32, 2, 16); \
- DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 128, 16, 16, 128, 16, 1, 8); \
- DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 64, 8, 16, 64, 8, 2, 4); \
- megdnn_assert(false, \
- "unsupported threadblock shape (%dx%dx%d) and warp shape " \
- "(%dx%dx%d)", \
- threadblock_shape.m(), threadblock_shape.n(), \
- threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \
- warp_shape.k());
- using ElementOutput = int8_t;
- using ElementAccumulator = int32_t;
- using ElementBias = int32_t;
- using ElementCompute = float;
- using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode;
- switch (nonlinear_mode) {
- case NonlineMode::IDENTITY: {
- using EpilogueOp =
- cutlass::epilogue::thread::BiasAddLinearCombinationClamp<
- ElementOutput, 4, ElementAccumulator, ElementBias,
- ElementCompute>;
- typename EpilogueOp::Params epilogue{alpha, beta, gamma};
- DISPATCH_KERNEL;
- }
- case NonlineMode::RELU: {
- using EpilogueOp = cutlass::epilogue::thread::
- BiasAddLinearCombinationReluClamp<
- ElementOutput, 4, ElementAccumulator, ElementBias,
- ElementCompute>;
- typename EpilogueOp::Params epilogue{alpha, beta, gamma, 0};
- DISPATCH_KERNEL;
- }
- case NonlineMode::H_SWISH: {
- using EpilogueOp = cutlass::epilogue::thread::
- BiasAddLinearCombinationHSwishClamp<
- ElementOutput, 4, ElementAccumulator, ElementBias,
- ElementCompute>;
- typename EpilogueOp::Params epilogue{alpha, beta, gamma, scale};
- DISPATCH_KERNEL;
- }
- default:
- megdnn_assert(false,
- "unsupported nonlinear mode for conv bias operator");
- }
- #undef DISPATCH_KERNEL_WITH_TILE_SHAPE
- #undef DISPATCH_KERNEL
- }
- #endif
-
- #define INST(need_load_from_const_mem) \
- template void megdnn::cuda::cutlass_wrapper:: \
- do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4< \
- need_load_from_const_mem>( \
- const int8_t* d_src, const int8_t* d_filter, \
- const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, \
- int* workspace, const convolution::ConvParam& param, \
- uint32_t nonlinear_mode, float alpha, float beta, \
- float gamma, float scale, \
- const GemmCoord& threadblock_shape, \
- const GemmCoord& warp_shape, cudaStream_t stream);
- INST(true);
- INST(false);
- #undef INST
-
- /* ===== cutlass kernel wrapper for nchw4 layout and nchw output ===== */
- #if MEGDNN_TEGRA_X1
- template <bool NeedLoadFromConstMem>
- void megdnn::cuda::cutlass_wrapper::
- do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nchw(
- const int8_t* /* d_src */, const int8_t* /* d_filter */,
- const float* /* d_bias */, const float* /* d_z */,
- float* /* d_dst */, int* /* workspace */,
- const convolution::ConvParam& /* param */,
- uint32_t /* nonlinear_mode */, float /* alpha */,
- float /* beta */, float /* gamma */, float /* scale */,
- const GemmCoord& /* threadblock_shape */,
- const GemmCoord& /* warp_shape */, cudaStream_t /* stream */) {}
- #else
- template <bool NeedLoadFromConstMem>
- void megdnn::cuda::cutlass_wrapper::
- do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nchw(
- const int8_t* d_src, const int8_t* d_filter,
- const float* d_bias, const float* d_z, float* d_dst,
- int* workspace, const convolution::ConvParam& param,
- uint32_t nonlinear_mode, float alpha, float beta, float gamma,
- float scale, const GemmCoord& threadblock_shape,
- const GemmCoord& warp_shape, cudaStream_t stream) {
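- // This variant writes float NCHW output: the destination layout is
- // TensorNCHW, the epilogues are the non-clamping BiasAddLinearCombination*
- // family with vector width 1, and the math op is spelled out as
- // cutlass::arch::OpMultiplyAdd.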
- #define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \
- threadblock_k_, warp_m_, warp_n_, \
- warp_k_, aligned_) \
- if (threadblock_shape.m() == threadblock_m_ && \
- threadblock_shape.n() == threadblock_n_ && \
- threadblock_shape.k() == threadblock_k_ && \
- warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \
- warp_shape.k() == warp_k_) { \
- using ThreadBlockShape = \
- cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \
- threadblock_k_>; \
- using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \
- using InstructionShape = cutlass::gemm::GemmShape<1, 1, 4>; \
- using Convolution = cutlass::convolution::device::Convolution< \
- int8_t, cutlass::layout::TensorNCxHWx<4>, int8_t, \
- cutlass::layout::TensorCxRSKx<4>, ElementOutput, \
- cutlass::layout::TensorNCHW, float, \
- cutlass::layout::TensorNCHW, int32_t, \
- cutlass::convolution::ConvType::kConvolution, \
- cutlass::arch::OpClassSimt, cutlass::arch::Sm61, \
- ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \
- cutlass::convolution::threadblock:: \
- ConvolutionNCxHWxThreadblockSwizzle< \
- cutlass::convolution::ConvType::kConvolution>, \
- 2, 4, aligned_, NeedLoadFromConstMem, \
- cutlass::arch::OpMultiplyAdd>; \
- typename Convolution::ConvolutionParameter conv_param{ \
- param.n, param.ci, param.co, param.hi, param.wi, \
- param.fh, param.fw, param.ho, param.wo, param.sh, \
- param.sw, param.ph, param.pw, 1, 1}; \
- return cutlass_convolution_wrapper<Convolution>( \
- d_src, d_filter, d_bias, d_z, d_dst, workspace, conv_param, \
- epilogue, stream); \
- }
- #define DISPATCH_KERNEL \
- DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 32, 64, 32, 32, 16); \
- DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 32, 64, 32, 32, 16); \
- DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 64, 32, 32, 16); \
- DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 32, 64, 32, 32, 16); \
- DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 16); \
- DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 64, 32, 64, 32, 32, 16); \
- DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 32, 32, 64, 32, 16); \
- DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 32, 32, 64, 32, 32, 16); \
- DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 32, 32, 32, 32, 32, 16); \
- DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 64, 8, 16, 64, 8, 4); \
- megdnn_assert(false, \
- "unsupported threadblock shape (%dx%dx%d) and warp shape " \
- "(%dx%dx%d)", \
- threadblock_shape.m(), threadblock_shape.n(), \
- threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \
- warp_shape.k());
- using ElementOutput = float;
- using ElementAccumulator = int32_t;
- using ElementBias = float;
- using ElementCompute = float;
- using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode;
- switch (nonlinear_mode) {
- case NonlineMode::IDENTITY: {
- using EpilogueOp =
- cutlass::epilogue::thread::BiasAddLinearCombination<
- ElementOutput, 1, ElementAccumulator, ElementBias,
- ElementCompute>;
- typename EpilogueOp::Params epilogue{alpha, beta, gamma};
- DISPATCH_KERNEL;
- }
- case NonlineMode::RELU: {
- using EpilogueOp =
- cutlass::epilogue::thread::BiasAddLinearCombinationRelu<
- ElementOutput, 1, ElementAccumulator, ElementBias,
- ElementCompute>;
- typename EpilogueOp::Params epilogue{alpha, beta, gamma, 0};
- DISPATCH_KERNEL;
- }
- case NonlineMode::H_SWISH: {
- using EpilogueOp =
- cutlass::epilogue::thread::BiasAddLinearCombinationHSwish<
- ElementOutput, 1, ElementAccumulator, ElementBias,
- ElementCompute>;
- typename EpilogueOp::Params epilogue{alpha, beta, gamma, scale};
- DISPATCH_KERNEL;
- }
- default:
- megdnn_assert(false,
- "unsupported nonlinear mode for conv bias operator");
- }
- #undef DISPATCH_KERNEL_WITH_TILE_SHAPE
- #undef DISPATCH_KERNEL
- }
- #endif
-
- #define INST(need_load_from_const_mem) \
- template void megdnn::cuda::cutlass_wrapper:: \
- do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nchw< \
- need_load_from_const_mem>( \
- const int8_t* d_src, const int8_t* d_filter, \
- const float* d_bias, const float* d_z, float* d_dst, \
- int* workspace, const convolution::ConvParam& param, \
- uint32_t nonlinear_mode, float alpha, float beta, \
- float gamma, float scale, \
- const GemmCoord& threadblock_shape, \
- const GemmCoord& warp_shape, cudaStream_t stream);
- INST(true);
- INST(false);
- #undef INST
-
- /* ====== cutlass kernel wrapper for nchw4 layout and nchw32 output ====== */
- #if MEGDNN_TEGRA_X1
- template <bool NeedLoadFromConstMem>
- void megdnn::cuda::cutlass_wrapper::
- do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32(
- const int8_t* /* d_src */, const int8_t* /* d_filter */,
- const int32_t* /* d_bias */, const int8_t* /* d_z */,
- int8_t* /* d_dst */, int* /* workspace */,
- const convolution::ConvParam& /* param */,
- uint32_t /* nonlinear_mode */, float /* alpha */,
- float /* beta */, float /* gamma */, float /* scale */,
- const GemmCoord& /* threadblock_shape */,
- const GemmCoord& /* warp_shape */, cudaStream_t /* stream */) {}
- #else
- template <bool NeedLoadFromConstMem>
- void megdnn::cuda::cutlass_wrapper::
- do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32(
- const int8_t* d_src, const int8_t* d_filter,
- const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst,
- int* workspace, const convolution::ConvParam& param,
- uint32_t nonlinear_mode, float alpha, float beta, float gamma,
- float scale, const GemmCoord& threadblock_shape,
- const GemmCoord& warp_shape, cudaStream_t stream) {
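- // Mirrors the dp4a NCHW4 wrapper above, with the destination layout
- // switched to NCHW32 (TensorNCxHWx<32>) and a fixed two-stage pipeline.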
- #define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \
- threadblock_k_, warp_m_, warp_n_, \
- warp_k_, aligned_) \
- if (threadblock_shape.m() == threadblock_m_ && \
- threadblock_shape.n() == threadblock_n_ && \
- threadblock_shape.k() == threadblock_k_ && \
- warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \
- warp_shape.k() == warp_k_) { \
- using ThreadBlockShape = \
- cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \
- threadblock_k_>; \
- using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \
- using InstructionShape = cutlass::gemm::GemmShape<1, 1, 4>; \
- using Convolution = cutlass::convolution::device::Convolution< \
- int8_t, cutlass::layout::TensorNCxHWx<4>, int8_t, \
- cutlass::layout::TensorCxRSKx<4>, ElementOutput, \
- cutlass::layout::TensorNCxHWx<32>, int32_t, \
- cutlass::layout::TensorNCxHWx<32>, int32_t, \
- cutlass::convolution::ConvType::kConvolution, \
- cutlass::arch::OpClassSimt, cutlass::arch::Sm61, \
- ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \
- cutlass::convolution::threadblock:: \
- ConvolutionNCxHWxThreadblockSwizzle< \
- cutlass::convolution::ConvType::kConvolution>, \
- 2, 4, aligned_, NeedLoadFromConstMem>; \
- typename Convolution::ConvolutionParameter conv_param{ \
- param.n, param.ci, param.co, param.hi, param.wi, \
- param.fh, param.fw, param.ho, param.wo, param.sh, \
- param.sw, param.ph, param.pw, 1, 1}; \
- return cutlass_convolution_wrapper<Convolution>( \
- d_src, d_filter, d_bias, d_z, d_dst, workspace, conv_param, \
- epilogue, stream); \
- }
- #define DISPATCH_KERNEL \
- DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 32, 64, 32, 32, 16); \
- DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 32, 64, 32, 32, 16); \
- DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 64, 32, 32, 16); \
- DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 32, 64, 32, 32, 16); \
- DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 16); \
- DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 64, 32, 64, 32, 32, 16); \
- DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 32, 32, 64, 32, 16); \
- DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 32, 32, 64, 32, 32, 16); \
- DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 32, 32, 32, 32, 32, 16); \
- megdnn_assert(false, \
- "unsupported threadblock shape (%dx%dx%d) and warp shape " \
- "(%dx%dx%d)", \
- threadblock_shape.m(), threadblock_shape.n(), \
- threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \
- warp_shape.k());
- using ElementOutput = int8_t;
- using ElementAccumulator = int32_t;
- using ElementBias = int32_t;
- using ElementCompute = float;
- using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode;
- switch (nonlinear_mode) {
- case NonlineMode::IDENTITY: {
- using EpilogueOp =
- cutlass::epilogue::thread::BiasAddLinearCombinationClamp<
- ElementOutput, 4, ElementAccumulator, ElementBias,
- ElementCompute>;
- typename EpilogueOp::Params epilogue{alpha, beta, gamma};
- DISPATCH_KERNEL;
- }
- case NonlineMode::RELU: {
- using EpilogueOp = cutlass::epilogue::thread::
- BiasAddLinearCombinationReluClamp<
- ElementOutput, 4, ElementAccumulator, ElementBias,
- ElementCompute>;
- typename EpilogueOp::Params epilogue{alpha, beta, gamma, 0};
- DISPATCH_KERNEL;
- }
- case NonlineMode::H_SWISH: {
- using EpilogueOp = cutlass::epilogue::thread::
- BiasAddLinearCombinationHSwishClamp<
- ElementOutput, 4, ElementAccumulator, ElementBias,
- ElementCompute>;
- typename EpilogueOp::Params epilogue{alpha, beta, gamma, scale};
- DISPATCH_KERNEL;
- }
- default:
- megdnn_assert(false,
- "unsupported nonlinear mode for conv bias operator");
- }
- #undef DISPATCH_KERNEL_WITH_TILE_SHAPE
- #undef DISPATCH_KERNEL
- }
- #endif
-
- #define INST(need_load_from_const_mem) \
- template void megdnn::cuda::cutlass_wrapper:: \
- do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32< \
- need_load_from_const_mem>( \
- const int8_t* d_src, const int8_t* d_filter, \
- const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, \
- int* workspace, const convolution::ConvParam& param, \
- uint32_t nonlinear_mode, float alpha, float beta, \
- float gamma, float scale, \
- const GemmCoord& threadblock_shape, \
- const GemmCoord& warp_shape, cudaStream_t stream);
- INST(true);
- INST(false);
- #undef INST
-
- // vim: syntax=cuda.doxygen