diff --git a/dnn/scripts/cutlass_generator/BUILD b/dnn/scripts/cutlass_generator/BUILD
index bb1e0b5f..361bb885 100644
--- a/dnn/scripts/cutlass_generator/BUILD
+++ b/dnn/scripts/cutlass_generator/BUILD
@@ -9,6 +9,7 @@ genrule(
     CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations gemm --type tensorop1688 $(@D)
     CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations gemv --type simt $(@D)
     CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations deconv --type simt $(@D)
+    CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations deconv --type tensorop8816 $(@D)
     CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations conv2d --type simt $(@D)
     CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations conv2d --type tensorop8816 $(@D)
     CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations conv2d --type tensorop8832 $(@D)
diff --git a/dnn/scripts/cutlass_generator/conv2d_operation.py b/dnn/scripts/cutlass_generator/conv2d_operation.py
index 539574b9..9e267fcb 100644
--- a/dnn/scripts/cutlass_generator/conv2d_operation.py
+++ b/dnn/scripts/cutlass_generator/conv2d_operation.py
@@ -337,7 +337,10 @@ def GenerateConv2d(conv_kind, tile_descriptions, src_layout, flt_layout, dst_lay
         else:
             swizzling_functor = SwizzlingFunctor.ConvFpropNCxHWx
     else:
-        swizzling_functor = SwizzlingFunctor.ConvDgradNCxHWx
+        if implicit_gemm_mode == ImplicitGemmMode.GemmTN:
+            swizzling_functor = SwizzlingFunctor.ConvDgradTrans
+        else:
+            swizzling_functor = SwizzlingFunctor.ConvDgradNCxHWx
 
     # skip rule
     def filter_tile_with_layout(tile: TileDescription, layout: LayoutType) -> bool:
diff --git a/dnn/scripts/cutlass_generator/gen_list.py b/dnn/scripts/cutlass_generator/gen_list.py
index d3b06776..fc5f980b 100644
--- a/dnn/scripts/cutlass_generator/gen_list.py
+++ b/dnn/scripts/cutlass_generator/gen_list.py
@@ -36,6 +36,7 @@ if __name__ == "__main__":
         write_op_list(f, "gemm", "tensorop884")
         write_op_list(f, "gemv", "simt")
         write_op_list(f, "deconv", "simt")
+        write_op_list(f, "deconv", "tensorop8816")
         write_op_list(f, "conv2d", "simt")
         write_op_list(f, "conv2d", "tensorop8816")
         write_op_list(f, "conv2d", "tensorop8832")
diff --git a/dnn/scripts/cutlass_generator/generator.py b/dnn/scripts/cutlass_generator/generator.py
index b5814746..91881e26 100644
--- a/dnn/scripts/cutlass_generator/generator.py
+++ b/dnn/scripts/cutlass_generator/generator.py
@@ -445,6 +445,53 @@ def GenerateDeconv_Simt(args):
                                   use_special_optimization)
     return operations
 
+def GenerateDeconv_TensorOp_8816(args):
+    operations = []
+
+    layouts = [
+        (LayoutType.TensorNHWC, LayoutType.TensorCK4RS4, 32),
+        (LayoutType.TensorNHWC, LayoutType.TensorCK8RS8, 64),
+        (LayoutType.TensorNHWC, LayoutType.TensorCK16RS16, 128),
+    ]
+
+    math_instructions = [
+        MathInstruction(                                  \
+            [8, 8, 16],                                   \
+            DataType.s8, DataType.s8, DataType.s32,      \
+            OpcodeClass.TensorOp,                         \
+            MathOperation.multiply_add_saturate),
+    ]
+
+    dst_layouts = [
+        LayoutType.TensorNHWC,
+    ]
+
+    dst_types = [
+        DataType.s8,
+    ]
+
+    use_special_optimization = SpecialOptimizeDesc.DeconvDoubleUpsampling
+
+    min_cc = 75
+    max_cc = 1024
+
+    cuda_major = 10
+    cuda_minor = 2
+
+    for math_inst in math_instructions:
+        for layout in layouts:
+            for dst_type, dst_layout in zip(dst_types, dst_layouts):
+                tile_descriptions = [
+                    TileDescription([128, 32, 32], 1, [2, 1, 1], math_inst, min_cc, max_cc),
+                    TileDescription([64, 16, 32], 2, [1, 1, 1], math_inst, min_cc, max_cc),
+                ]
+                for tile in tile_descriptions:
+                    dst_align = 32 if tile.threadblock_shape[1] == 16 else 64
+                    operations += GenerateConv2d(ConvKind.Dgrad, [tile], layout[0],
+                                                 layout[1], dst_layout, dst_type,
+                                                 min_cc, layout[2], layout[2], dst_align, use_special_optimization,
+                                                 ImplicitGemmMode.GemmTN, False, cuda_major, cuda_minor)
+    return operations
+
 ################################################################################
 # parameters
 #   Edge - for tiles, the edges represent the length of one side
@@ -820,9 +867,12 @@ def GenerateConv2dOperations(args):
     return GenerateConv2d_TensorOp_8832(args)
 
 def GenerateDeconvOperations(args):
-    assert args.type == "simt", "operation deconv only support" \
-        "simt. (got:{})".format(args.type)
-    return GenerateDeconv_Simt(args)
+    if args.type == "simt":
+        return GenerateDeconv_Simt(args)
+    else:
+        assert args.type == "tensorop8816", "operation deconv only supports " \
+            "simt and tensorop8816. (got:{})".format(args.type)
+        return GenerateDeconv_TensorOp_8816(args)
 
 def GenerateGemmOperations(args):
     if args.type == "tensorop884":
diff --git a/dnn/scripts/cutlass_generator/library.py b/dnn/scripts/cutlass_generator/library.py
index 9308357b..466ddc25 100644
--- a/dnn/scripts/cutlass_generator/library.py
+++ b/dnn/scripts/cutlass_generator/library.py
@@ -280,6 +280,9 @@ class LayoutType(enum.Enum):
     TensorC32RSK32 = enum_auto()
     TensorC64RSK64 = enum_auto()
     TensorK4RSC4 = enum_auto()
+    TensorCK4RS4 = enum_auto()
+    TensorCK8RS8 = enum_auto()
+    TensorCK16RS16 = enum_auto()
 
 #
 LayoutTag = {
@@ -303,7 +306,10 @@ LayoutTag = {
     LayoutType.TensorC32RSK32: 'cutlass::layout::TensorCxRSKx<32>',
     LayoutType.TensorNC64HW64: 'cutlass::layout::TensorNCxHWx<64>',
     LayoutType.TensorC64RSK64: 'cutlass::layout::TensorCxRSKx<64>',
-    LayoutType.TensorK4RSC4: 'cutlass::layout::TensorKxRSCx<4>',
+    LayoutType.TensorK4RSC4: 'cutlass::layout::TensorKxRSCx<4>',
+    LayoutType.TensorCK4RS4: 'cutlass::layout::TensorCKxRSx<4>',
+    LayoutType.TensorCK8RS8: 'cutlass::layout::TensorCKxRSx<8>',
+    LayoutType.TensorCK16RS16: 'cutlass::layout::TensorCKxRSx<16>',
 }
 
 #
@@ -342,6 +348,9 @@ ShortLayoutTypeNames = {
     LayoutType.TensorC32RSK32: 'c32rsk32',
     LayoutType.TensorC64RSK64: 'c64rsk64',
     LayoutType.TensorK4RSC4: 'k4rsc4',
+    LayoutType.TensorCK4RS4: 'ck4rs4',
+    LayoutType.TensorCK8RS8: 'ck8rs8',
+    LayoutType.TensorCK16RS16: 'ck16rs16',
 }
 
 #
@@ -484,6 +493,7 @@ class SwizzlingFunctor(enum.Enum):
     ConvFpropNCxHWx = enum_auto()
     ConvFpropTrans = enum_auto()
     ConvDgradNCxHWx = enum_auto()
+    ConvDgradTrans = enum_auto()
 
 #
 SwizzlingFunctorTag = {
@@ -494,6 +504,7 @@ SwizzlingFunctorTag = {
     SwizzlingFunctor.ConvFpropNCxHWx: 'cutlass::conv::threadblock::ConvolutionFpropNCxHWxThreadblockSwizzle',
     SwizzlingFunctor.ConvFpropTrans: 'cutlass::conv::threadblock::ConvolutionFpropTransThreadblockSwizzle',
     SwizzlingFunctor.ConvDgradNCxHWx: 'cutlass::conv::threadblock::ConvolutionDgradNCxHWxThreadblockSwizzle',
+    SwizzlingFunctor.ConvDgradTrans: 'cutlass::conv::threadblock::ConvolutionDgradTransThreadblockSwizzle',
 }
 
 ###################################################################################################
diff --git a/dnn/scripts/cutlass_generator/list.bzl b/dnn/scripts/cutlass_generator/list.bzl
index 52b16821..596b4561 100644
--- a/dnn/scripts/cutlass_generator/list.bzl
+++ b/dnn/scripts/cutlass_generator/list.bzl
@@ -464,6 +464,19 @@ cutlass_gen_list = [
     "cutlass_simt_s8_idgrad_id_s8_16x64x8_16x64x8_2_nc4hw4_k4rsc4.cu",
     "cutlass_simt_s8_idgrad_s2_id_s8_16x64x8_16x64x8_2_nc4hw4_k4rsc4.cu",
     "all_deconv_simt_operations.cu",
+    "cutlass_tensorop_s8_i8816dgrad_id_s8_128x32x32_64x32x32_1_nhwc_ck4rs4.cu",
"cutlass_tensorop_s8_i8816dgrad_s2_id_s8_128x32x32_64x32x32_1_nhwc_ck4rs4.cu", + "cutlass_tensorop_s8_i8816dgrad_id_s8_64x16x32_64x16x32_2_nhwc_ck4rs4.cu", + "cutlass_tensorop_s8_i8816dgrad_s2_id_s8_64x16x32_64x16x32_2_nhwc_ck4rs4.cu", + "cutlass_tensorop_s8_i8816dgrad_id_s8_128x32x32_64x32x32_1_nhwc_ck8rs8.cu", + "cutlass_tensorop_s8_i8816dgrad_s2_id_s8_128x32x32_64x32x32_1_nhwc_ck8rs8.cu", + "cutlass_tensorop_s8_i8816dgrad_id_s8_64x16x32_64x16x32_2_nhwc_ck8rs8.cu", + "cutlass_tensorop_s8_i8816dgrad_s2_id_s8_64x16x32_64x16x32_2_nhwc_ck8rs8.cu", + "cutlass_tensorop_s8_i8816dgrad_id_s8_128x32x32_64x32x32_1_nhwc_ck16rs16.cu", + "cutlass_tensorop_s8_i8816dgrad_s2_id_s8_128x32x32_64x32x32_1_nhwc_ck16rs16.cu", + "cutlass_tensorop_s8_i8816dgrad_id_s8_64x16x32_64x16x32_2_nhwc_ck16rs16.cu", + "cutlass_tensorop_s8_i8816dgrad_s2_id_s8_64x16x32_64x16x32_2_nhwc_ck16rs16.cu", + "all_deconv_tensorop8816_operations.cu", "cutlass_simt_s8_ifprop_id_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4.cu", "cutlass_simt_s8_ifprop_1x1_id_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4.cu", "cutlass_simt_s8_ifprop_relu_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4.cu", diff --git a/dnn/src/CMakeLists.txt b/dnn/src/CMakeLists.txt index a2e9b9b2..a24448fa 100644 --- a/dnn/src/CMakeLists.txt +++ b/dnn/src/CMakeLists.txt @@ -155,6 +155,7 @@ if(MGE_WITH_CUDA) gen_cutlass_kimpl(gemm tensorop1688 CUTLASS_SOURCES) gen_cutlass_kimpl(gemv simt CUTLASS_SOURCES) gen_cutlass_kimpl(deconv simt CUTLASS_SOURCES) + gen_cutlass_kimpl(deconv tensorop8816 CUTLASS_SOURCES) gen_cutlass_kimpl(conv2d simt CUTLASS_SOURCES) gen_cutlass_kimpl(conv2d tensorop8816 CUTLASS_SOURCES) gen_cutlass_kimpl(conv2d tensorop8832 CUTLASS_SOURCES) diff --git a/dnn/src/cuda/convolution/backward_data/algo.cpp b/dnn/src/cuda/convolution/backward_data/algo.cpp index b4a24325..2bde8791 100644 --- a/dnn/src/cuda/convolution/backward_data/algo.cpp +++ b/dnn/src/cuda/convolution/backward_data/algo.cpp @@ -36,6 +36,12 @@ ConvolutionBackwardDataImpl::AlgoPack::AlgoPack() { int8_algos.push_back(&algo); } + fill_int8_imma_algos(); + for (auto&& algo : int8_nhwc_imma) { + all_algos.push_back(&algo); + int8_algos.push_back(&algo); + } + int8_algos.push_back(&int8_nchw_dotprod); all_algos.push_back(&int8_nchw_dotprod); diff --git a/dnn/src/cuda/convolution/backward_data/algo.h b/dnn/src/cuda/convolution/backward_data/algo.h index f87479e5..821ae5f2 100644 --- a/dnn/src/cuda/convolution/backward_data/algo.h +++ b/dnn/src/cuda/convolution/backward_data/algo.h @@ -40,7 +40,8 @@ public: CUDA_BFLOAT16, CUDA_GROUP_CONV_GENERAL, CUDA_IMPLICIT_GEMM_NCHW4_DOTPROD_INT8, - CUDA_IMPLICIT_GEMM_NCHW_DOTPROD_INT8 + CUDA_IMPLICIT_GEMM_NCHW_DOTPROD_INT8, + CUDA_IMPLICIT_GEMM_NHWC_IMMA_INT8 }; using Mapper = std::unordered_map; @@ -299,11 +300,53 @@ private: const void* get_available_op(const SizeArgs& args) const; }; +class ConvolutionBackwardDataImpl::AlgoInt8NHWCIMMAImplicitGemm final + : public AlgoBase { +public: + struct AlgoParam { + int threadblock_m; + int threadblock_n; + int threadblock_k; + int warp_m; + int warp_n; + int warp_k; + int stage; + int access_size; + std::string to_string() { + return ssprintf("_%dX%dX%d_%dX%dX%d_%dstage_%d", threadblock_m, + threadblock_n, threadblock_k, warp_m, warp_n, + warp_k, stage, access_size); + } + }; + AlgoInt8NHWCIMMAImplicitGemm(AlgoParam algo_param) + : m_algo_param{algo_param}, + m_name{ssprintf("INT8_NHWC_IMMA_IMPLICIT_GEMM%s", + m_algo_param.to_string().c_str())} {} + bool is_available(const SizeArgs& args) const override; + size_t 
+    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
+    void exec(const ExecArgs& args) const override;
+    const char* name() const override { return m_name.c_str(); }
+    AlgoAttribute attribute() const override {
+        return AlgoAttribute::REPRODUCIBLE;
+    }
+    MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_NHWC_IMMA_INT8)
+
+private:
+    WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr,
+                                         const SizeArgs& args) const;
+    const void* get_available_op(const SizeArgs& args) const;
+    void reorder_filter(const ExecArgs& args, const int interleaved,
+                        int8_t* reordered_filter) const;
+    AlgoParam m_algo_param;
+    std::string m_name;
+};
+
 class ConvolutionBackwardDataImpl::AlgoPack : NonCopyableObj {
     // defined in cudnn.cpp
     void fill_cudnn_algos();
     // defined in implicit_gemm_int8_nchw4_dp4a.cpp
     void fill_int8_dp4a_algos();
+    // defined in implicit_gemm_int8_nhwc_imma.cpp
+    void fill_int8_imma_algos();
 
     AlgoBase::Mapper m_all_algos_map;
 
@@ -318,6 +361,7 @@ public:
     AlgoGroupConvGeneral group;
     std::vector<AlgoInt8NCHW4DotProdImplicitGemm> int8_nchw4_dotprod;
     AlgoInt8NCHWDotProdImplicitGemm int8_nchw_dotprod;
+    std::vector<AlgoInt8NHWCIMMAImplicitGemm> int8_nhwc_imma;
 
     std::vector<AlgoBase*>
             //! all algorithms
diff --git a/dnn/src/cuda/convolution/backward_data/deconv_int8_helper.cu b/dnn/src/cuda/convolution/backward_data/deconv_int8_helper.cu
index f3d284c3..2df449f5 100644
--- a/dnn/src/cuda/convolution/backward_data/deconv_int8_helper.cu
+++ b/dnn/src/cuda/convolution/backward_data/deconv_int8_helper.cu
@@ -11,6 +11,7 @@
  */
 
 #include "src/cuda/convolution/backward_data/deconv_int8_helper.cuh"
+#include "src/cuda/transpose_utils.cuh"
 
 using namespace megdnn;
 using namespace cuda;
@@ -21,7 +22,6 @@ using namespace deconv;
 
 namespace {
 
-//
 __global__ void reorder_filter_nc4hw4_to_n4hwc4_kernel(
         int8_t* __restrict__ dst, const int8_t* __restrict__ src, uint32_t OC,
         uint32_t IC, uint32_t FHFW) {
@@ -30,32 +30,55 @@
     const int32_t fhfw = blockIdx.x * BLOCKSIZE_Y + threadIdx.x;
 
     if (fhfw < FHFW && icb < IC / 4) {
-        int src0 = *reinterpret_cast<const int*>(
-                src + (ocb * 4 + 0) * IC * FHFW + (icb * FHFW + fhfw) * 4);
-        int src1 = *reinterpret_cast<const int*>(
-                src + (ocb * 4 + 1) * IC * FHFW + (icb * FHFW + fhfw) * 4);
-        int src2 = *reinterpret_cast<const int*>(
-                src + (ocb * 4 + 2) * IC * FHFW + (icb * FHFW + fhfw) * 4);
-        int src3 = *reinterpret_cast<const int*>(
-                src + (ocb * 4 + 3) * IC * FHFW + (icb * FHFW + fhfw) * 4);
+        int src_value[4], dst_value[4];
+#pragma unroll
+        for (int i = 0; i < 4; i++) {
+            src_value[i] = *reinterpret_cast<const int*>(
+                    src + (ocb * 4 + i) * IC * FHFW + (icb * FHFW + fhfw) * 4);
+        }
+
         // transpose 4x4
-        int dst01_lo = __byte_perm(src0, src1, 0x5140);
-        int dst01_hi = __byte_perm(src0, src1, 0x7362);
-        int dst23_lo = __byte_perm(src2, src3, 0x5140);
-        int dst23_hi = __byte_perm(src2, src3, 0x7362);
-        int dst0 = __byte_perm(dst01_lo, dst23_lo, 0x5410);
-        int dst1 = __byte_perm(dst01_lo, dst23_lo, 0x7632);
-        int dst2 = __byte_perm(dst01_hi, dst23_hi, 0x5410);
-        int dst3 = __byte_perm(dst01_hi, dst23_hi, 0x7632);
-
-        *reinterpret_cast<int*>(
-                dst + (ocb * FHFW * IC + fhfw * IC + icb * 4 + 0) * 4) = dst0;
-        *reinterpret_cast<int*>(
-                dst + (ocb * FHFW * IC + fhfw * IC + icb * 4 + 1) * 4) = dst1;
-        *reinterpret_cast<int*>(
-                dst + (ocb * FHFW * IC + fhfw * IC + icb * 4 + 2) * 4) = dst2;
-        *reinterpret_cast<int*>(
-                dst + (ocb * FHFW * IC + fhfw * IC + icb * 4 + 3) * 4) = dst3;
+        transpose_int8_interleavedx4<4, int>(src_value, dst_value);
+
+#pragma unroll
+        for (int i = 0; i < 4; i++) {
+            *reinterpret_cast<int*>(
+                    dst + (ocb * FHFW * IC + fhfw * IC + icb * 4 + i) * 4) =
+                    dst_value[i];
+        }
     }
 }
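+
+// Reorders a filter stored as OC x FH x FW x IC (NHWC) into the CKxRSx
+// "interleaved" layout consumed by the int8 IMMA dgrad kernels: each thread
+// loads `interleaved` consecutive output channels of one (fhfw, icb)
+// position (4 int8 each, read as one int), transposes them as 4x4 blocks in
+// registers, and writes four vec_type words. OC must therefore be a multiple
+// of `interleaved` and IC a multiple of 4 (guaranteed by is_available).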
+template <uint32_t interleaved, typename vec_type>
+__global__ void reorder_filter_nhwc_to_cnxhwx_kernel(
+        int8_t* __restrict__ dst, const int8_t* __restrict__ src, uint32_t OC,
+        uint32_t IC, uint32_t FHFW) {
+    uint32_t lane = threadIdx.x + blockIdx.x * blockDim.x;
+    const int32_t ocb = lane / (FHFW * IC / 4);
+    const int32_t fhfw_icb = lane % (FHFW * IC / 4);
+    const int32_t fhfw = fhfw_icb / (IC / 4);
+    const int32_t icb = fhfw_icb % (IC / 4);
+
+    if (ocb < OC / interleaved && fhfw < FHFW) {
+        int src_value[interleaved];
+        vec_type dst_value[4];
+
+#pragma unroll
+        for (int i = 0; i < interleaved; i++) {
+            src_value[i] = *reinterpret_cast<const int*>(
+                    src + (ocb * interleaved + i) * FHFW * IC + fhfw * IC +
+                    icb * 4);
+        }
+
+        transpose_int8_interleavedx4<interleaved, vec_type>(src_value,
+                                                            dst_value);
+
+#pragma unroll
+        for (int i = 0; i < 4; i++) {
+            *reinterpret_cast<vec_type*>(dst + (icb * 4 + i) * FHFW * OC +
+                                         (ocb * FHFW + fhfw) * interleaved) =
+                    dst_value[i];
+        }
+    }
+}
@@ -73,4 +96,27 @@
     after_kernel_launch();
 }
 
+void megdnn::cuda::deconv::reorder_filter_nhwc_to_cnxhwx(
+        int8_t* dst, const int8_t* src, uint32_t OC, uint32_t IC, uint32_t FH,
+        uint32_t FW, uint32_t interleaved, cudaStream_t stream) {
+    int32_t vthreads = OC / interleaved * IC / 4 * FH * FW;
+    int32_t nr_threads = std::min(256, vthreads);
+    int32_t nr_blocks = DIVUP(vthreads, nr_threads);
+
+    if (interleaved == 4) {
+        reorder_filter_nhwc_to_cnxhwx_kernel<4, int>
+                <<<nr_blocks, nr_threads, 0, stream>>>(dst, src, OC, IC,
+                                                       FH * FW);
+    } else if (interleaved == 8) {
+        reorder_filter_nhwc_to_cnxhwx_kernel<8, int2>
+                <<<nr_blocks, nr_threads, 0, stream>>>(dst, src, OC, IC,
+                                                       FH * FW);
+    } else {
+        reorder_filter_nhwc_to_cnxhwx_kernel<16, int4>
+                <<<nr_blocks, nr_threads, 0, stream>>>(dst, src, OC, IC,
+                                                       FH * FW);
+    }
+    after_kernel_launch();
+}
+
 // vim: syntax=cuda.doxygen
diff --git a/dnn/src/cuda/convolution/backward_data/deconv_int8_helper.cuh b/dnn/src/cuda/convolution/backward_data/deconv_int8_helper.cuh
index f50b3c36..ea7baaee 100644
--- a/dnn/src/cuda/convolution/backward_data/deconv_int8_helper.cuh
+++ b/dnn/src/cuda/convolution/backward_data/deconv_int8_helper.cuh
@@ -20,6 +20,10 @@ void reorder_filter_nc4hw4_to_n4hwc4(int8_t* dst, const int8_t* src,
                                      uint32_t OC, uint32_t IC, uint32_t FH,
                                      uint32_t FW, cudaStream_t stream);
 
+void reorder_filter_nhwc_to_cnxhwx(int8_t* dst, const int8_t* src, uint32_t OC,
+                                   uint32_t IC, uint32_t FH, uint32_t FW,
+                                   uint32_t interleaved, cudaStream_t stream);
+
 }  // namespace deconv
 }  // namespace cuda
 }  // namespace megdnn
diff --git a/dnn/src/cuda/convolution/backward_data/implicit_gemm_int8_nhwc_imma.cpp b/dnn/src/cuda/convolution/backward_data/implicit_gemm_int8_nhwc_imma.cpp
new file mode 100644
index 00000000..f2e7903f
--- /dev/null
+++ b/dnn/src/cuda/convolution/backward_data/implicit_gemm_int8_nhwc_imma.cpp
@@ -0,0 +1,214 @@
+/**
+ * \file
+ * dnn/src/cuda/convolution/backward_data/implicit_gemm_int8_nhwc_imma.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
+ * express or implied.
+ */
+
+#include "src/cuda/convolution/backward_data/algo.h"
+#include "src/cuda/convolution/backward_data/deconv_int8_helper.cuh"
+#include "src/cuda/convolution_helper/parameter.cuh"
+#include "src/cuda/cutlass/singleton.h"
+#include "src/cuda/utils.h"
+
+using namespace megdnn;
+using namespace cuda;
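+
+// The algo carries no kernel of its own: it resolves a generated cutlass
+// kernel from the operation table by building a ConvolutionKey that must
+// exactly match one registered by all_deconv_tensorop8816_operations.cu
+// (mma instruction shape fixed at 8x8x16; the filter layout follows
+// access_size).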
+const void*
+ConvolutionBackwardDataImpl::AlgoInt8NHWCIMMAImplicitGemm::get_available_op(
+        const SizeArgs& args) const {
+    using namespace cutlass::library;
+    auto&& fm = args.filter_meta;
+    size_t sh = fm.stride[0], sw = fm.stride[1];
+    cutlass::conv::SpecialOptimizeDesc special_optimization =
+            (sh == 2 && sw == 2) ? cutlass::conv::SpecialOptimizeDesc::
+                                           DECONV_DOUBLE_UPSAMPLING
+                                 : cutlass::conv::SpecialOptimizeDesc::NONE;
+    LayoutTypeID filter_layout;
+    if (m_algo_param.access_size == 16) {
+        filter_layout = LayoutTypeID::kTensorCK16RS16;
+    } else if (m_algo_param.access_size == 8) {
+        filter_layout = LayoutTypeID::kTensorCK8RS8;
+    } else {
+        megdnn_assert(m_algo_param.access_size == 4, "invalid access_size: %d",
+                      m_algo_param.access_size);
+        filter_layout = LayoutTypeID::kTensorCK4RS4;
+    }
+    ConvolutionKey key{
+            cutlass::conv::Operator::kDgrad,
+            NumericTypeID::kS8,
+            LayoutTypeID::kTensorNHWC,
+            NumericTypeID::kS8,
+            filter_layout,
+            NumericTypeID::kS8,
+            LayoutTypeID::kTensorNHWC,
+            NumericTypeID::kS32,
+            LayoutTypeID::kTensorNHWC,
+            cutlass::conv::ConvType::kConvolution,
+            m_algo_param.threadblock_m,
+            m_algo_param.threadblock_n,
+            m_algo_param.threadblock_k,
+            m_algo_param.warp_m,
+            m_algo_param.warp_n,
+            m_algo_param.warp_k,
+            8,
+            8,
+            16,
+            cutlass::epilogue::EpilogueType::kBiasAddLinearCombinationClamp,
+            m_algo_param.stage,
+            special_optimization,
+            false};
+    return (void*)Singleton::get().operation_table.find_op(key);
+}
+
+bool ConvolutionBackwardDataImpl::AlgoInt8NHWCIMMAImplicitGemm::is_available(
+        const SizeArgs& args) const {
+    auto&& fm = args.filter_meta;
+    if (fm.format != Param::Format::NHWC)
+        return false;
+
+    if (!args.grad_layout->is_contiguous() ||
+        !args.diff_layout->is_contiguous()) {
+        return false;
+    }
+
+    bool available = true;
+
+    auto src_dtype = args.diff_layout->dtype,
+         filter_dtype = args.filter_layout->dtype,
+         dst_dtype = args.grad_layout->dtype;
+    size_t co = args.diff_layout->operator[](3);
+    size_t ci = args.grad_layout->operator[](3);
+
+    available &= (src_dtype.enumv() == DTypeEnum::QuantizedS8 &&
+                  filter_dtype.enumv() == DTypeEnum::QuantizedS8 &&
+                  dst_dtype.enumv() == DTypeEnum::QuantizedS8);
+    // TODO support group deconv int8
+    available &= (fm.group == 1);
+    // mode must be cross correlation
+    available &= !fm.should_flip;
+    // mode must be 2D
+    available &= fm.spatial_ndim == 2;
+    // TODO: support dilation
+    available &= (fm.dilation[0] == 1 && fm.dilation[1] == 1);
+    // FIXME: too large filter size is not supported now
+    size_t kMaxFilterPixels =
+            848 / (m_algo_param.warp_k / m_algo_param.access_size) - 1;
+    available &= fm.spatial[0] * fm.spatial[1] <= kMaxFilterPixels;
+    // ci should be aligned with 4, and co should be aligned with
+    // algo_param.access_size
+    available &= ((ci % 4 == 0) && (co % m_algo_param.access_size == 0));
+    available &= (get_available_op(args) != nullptr);
+    // only support sm_75 or later, platform should have imma int8 support
+    available &= is_compute_capability_required(7, 5);
+
+    return available;
+}
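+
+// The only workspace needed is a buffer holding the CKxRSx-reordered copy of
+// the filter; diff and grad tensors are consumed in place by the cutlass
+// kernel.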
+WorkspaceBundle
+ConvolutionBackwardDataImpl::AlgoInt8NHWCIMMAImplicitGemm::get_workspace_bundle(
+        dt_byte* raw_ptr, const SizeArgs& args) const {
+    size_t ws_filter = args.filter_layout->span().dist_byte();
+    return WorkspaceBundle{raw_ptr, {ws_filter}};
+}
+
+size_t ConvolutionBackwardDataImpl::AlgoInt8NHWCIMMAImplicitGemm::
+        get_workspace_in_bytes(const SizeArgs& args) const {
+    return get_workspace_bundle(nullptr, args).total_size_in_bytes();
+}
+
+void ConvolutionBackwardDataImpl::AlgoInt8NHWCIMMAImplicitGemm::exec(
+        const ExecArgs& args) const {
+    auto&& param = args.opr->param();
+    auto&& fm = args.filter_meta;
+    size_t n = args.diff_layout->operator[](0),
+           co = args.diff_layout->operator[](3),
+           ho = args.diff_layout->operator[](1),
+           wo = args.diff_layout->operator[](2);
+    size_t ci = args.grad_layout->operator[](3),
+           hi = args.grad_layout->operator[](1),
+           wi = args.grad_layout->operator[](2);
+    size_t fh = fm.spatial[0], fw = fm.spatial[1];
+    size_t sh = fm.stride[0], sw = fm.stride[1];
+    size_t ph = fm.padding[0], pw = fm.padding[1];
+    size_t dh = param.dilate_h, dw = param.dilate_w;
+
+    auto&& stream = cuda_stream(args.opr->handle());
+
+    int8_t* filter_ptr = nullptr;
+    // TODO: weight preprocess
+    {
+        filter_ptr = reinterpret_cast<int8_t*>(args.workspace.raw_ptr);
+        // reorder filter from nhwc to cnxhwx
+        reorder_filter(args, m_algo_param.access_size, filter_ptr);
+    }
+
+    float diff_scale =
+                  args.diff_layout->dtype.param<dtype::QuantizedS8>().scale,
+          filter_scale =
+                  args.filter_layout->dtype.param<dtype::QuantizedS8>().scale,
+          grad_scale =
+                  args.grad_layout->dtype.param<dtype::QuantizedS8>().scale;
+
+    // \note these constants of cutlass epilogue will be passed to struct
+    // `ConvolutionArguments` by pointer and interpreted as ElementCompute*,
+    // a different dtype here results in undefined epilogue behaviors
+    float alpha = diff_scale * filter_scale / grad_scale, beta = 0.f,
+          gamma = 0.f, delta = 0.f;
+
+    using namespace cutlass::library;
+
+    const Operation* op = (const Operation*)get_available_op(args);
+
+    // gcc prints warnings when size_t values are implicitly narrowed to int
+    cutlass::conv::Conv2dProblemSize problem_size{
+            int(n),  int(hi), int(wi), int(ci),
+            int(co), int(fh), int(fw), int(ho),
+            int(wo), int(ph), int(pw), int(sh),
+            int(sw), int(dh), int(dw), cutlass::conv::Mode::kCrossCorrelation};
+
+    cutlass::library::ConvolutionArguments conv_args{
+            problem_size, args.diff_tensor->compatible_ptr<int8_t>(),
+            filter_ptr,   nullptr,
+            nullptr,      args.grad_tensor->compatible_ptr<int8_t>(),
+            &alpha,       &beta,
+            &gamma,       &delta,
+            nullptr,      nullptr,
+            nullptr,      nullptr};
+
+    cutlass_check(op->run(&conv_args, nullptr, stream));
+
+    after_kernel_launch();
+}
+
+void ConvolutionBackwardDataImpl::AlgoInt8NHWCIMMAImplicitGemm::reorder_filter(
+        const ExecArgs& args, const int interleaved,
+        int8_t* reordered_filter) const {
+    auto&& fm = args.filter_meta;
+    size_t co = args.diff_layout->operator[](3);
+    size_t ci = args.grad_layout->operator[](3);
+    size_t fh = fm.spatial[0], fw = fm.spatial[1];
+
+    auto&& stream = cuda_stream(args.opr->handle());
+
+    megdnn::cuda::deconv::reorder_filter_nhwc_to_cnxhwx(
+            reordered_filter, args.filter_tensor->compatible_ptr<int8_t>(), co,
+            ci, fh, fw, interleaved, stream);
+}
+
+void ConvolutionBackwardDataImpl::AlgoPack::fill_int8_imma_algos() {
+    using AlgoParam = AlgoInt8NHWCIMMAImplicitGemm::AlgoParam;
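+    // AlgoParam fields: {threadblock m/n/k, warp m/n/k, stages, access_size};
+    // access_size (4/8/16) selects the TensorCKxRSx<N> filter layout, so the
+    // output-channel count must be divisible by it (checked in is_available).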
+    int8_nhwc_imma.emplace_back(AlgoParam{64, 16, 32, 64, 16, 32, 2, 4});
+    int8_nhwc_imma.emplace_back(AlgoParam{64, 16, 32, 64, 16, 32, 2, 8});
+    int8_nhwc_imma.emplace_back(AlgoParam{64, 16, 32, 64, 16, 32, 2, 16});
+    int8_nhwc_imma.emplace_back(AlgoParam{128, 32, 32, 64, 32, 32, 1, 4});
+    int8_nhwc_imma.emplace_back(AlgoParam{128, 32, 32, 64, 32, 32, 1, 8});
+    int8_nhwc_imma.emplace_back(AlgoParam{128, 32, 32, 64, 32, 32, 1, 16});
+}
+
+// vim: syntax=cpp.doxygen
diff --git a/dnn/src/cuda/convolution/opr_impl.h b/dnn/src/cuda/convolution/opr_impl.h
index 226a34a6..08bc4e75 100644
--- a/dnn/src/cuda/convolution/opr_impl.h
+++ b/dnn/src/cuda/convolution/opr_impl.h
@@ -99,6 +99,7 @@ public:
     class AlgoBFloat16;
     class AlgoInt8NCHW4DotProdImplicitGemm;
     class AlgoInt8NCHWDotProdImplicitGemm;
+    class AlgoInt8NHWCIMMAImplicitGemm;
 
     class AlgoPack;
 
diff --git a/dnn/src/cuda/cutlass/initialize_all.cu b/dnn/src/cuda/cutlass/initialize_all.cu
index 89d32c5d..e314724f 100644
--- a/dnn/src/cuda/cutlass/initialize_all.cu
+++ b/dnn/src/cuda/cutlass/initialize_all.cu
@@ -60,6 +60,7 @@ void initialize_all_gemm_tensorop884_operations(Manifest& manifest);
 void initialize_all_gemm_tensorop1688_operations(Manifest& manifest);
 void initialize_all_conv2d_tensorop8816_operations(Manifest& manifest);
 void initialize_all_conv2d_tensorop8832_operations(Manifest& manifest);
+void initialize_all_deconv_tensorop8816_operations(Manifest& manifest);
 #endif
 
 void initialize_all(Manifest& manifest) {
@@ -71,6 +72,7 @@ void initialize_all(Manifest& manifest) {
     initialize_all_gemm_tensorop1688_operations(manifest);
     initialize_all_conv2d_tensorop8816_operations(manifest);
     initialize_all_conv2d_tensorop8832_operations(manifest);
+    initialize_all_deconv_tensorop8816_operations(manifest);
 #endif
 }
 
diff --git a/dnn/src/cuda/cutlass/library.h b/dnn/src/cuda/cutlass/library.h
index b9dd76a2..4d7a5b05 100644
--- a/dnn/src/cuda/cutlass/library.h
+++ b/dnn/src/cuda/cutlass/library.h
@@ -100,6 +100,9 @@ enum class LayoutTypeID {
     kTensorNC64HW64,
     kTensorC64RSK64,
     kTensorK4RSC4,
+    kTensorCK4RS4,
+    kTensorCK8RS8,
+    kTensorCK16RS16,
     kInvalid
 };
 
@@ -225,6 +228,7 @@ enum class ThreadblockSwizzleID {
     kConvolutionFpropNCxHWx,
     kConvolutionFpropTrans,
     kConvolutionDgradNCxHWx,
+    kConvolutionDgradTrans,
     kInvalid
 };
 
diff --git a/dnn/src/cuda/cutlass/library_internal.h b/dnn/src/cuda/cutlass/library_internal.h
index a6bfb01c..b12a0a52 100644
--- a/dnn/src/cuda/cutlass/library_internal.h
+++ b/dnn/src/cuda/cutlass/library_internal.h
@@ -340,6 +340,21 @@ struct LayoutMap<layout::TensorKxRSCx<4>> {
     static LayoutTypeID const kId = LayoutTypeID::kTensorK4RSC4;
 };
 
+template <>
+struct LayoutMap<layout::TensorCKxRSx<4>> {
+    static LayoutTypeID const kId = LayoutTypeID::kTensorCK4RS4;
+};
+
+template <>
+struct LayoutMap<layout::TensorCKxRSx<8>> {
+    static LayoutTypeID const kId = LayoutTypeID::kTensorCK8RS8;
+};
+
+template <>
+struct LayoutMap<layout::TensorCKxRSx<16>> {
+    static LayoutTypeID const kId = LayoutTypeID::kTensorCK16RS16;
+};
+
 /////////////////////////////////////////////////////////////////////////////////////////////////
 
@@ -556,6 +571,13 @@ struct ThreadblockSwizzleMap<
             ThreadblockSwizzleID::kConvolutionDgradNCxHWx;
 };
 
+template <>
+struct ThreadblockSwizzleMap<
+        conv::threadblock::ConvolutionDgradTransThreadblockSwizzle> {
+    static ThreadblockSwizzleID const kId =
+            ThreadblockSwizzleID::kConvolutionDgradTrans;
+};
+
 /////////////////////////////////////////////////////////////////////////////////////////////////
 
diff --git a/dnn/src/cuda/cutlass/util.cu b/dnn/src/cuda/cutlass/util.cu
index 6efcfa98..0506826c 100644
--- a/dnn/src/cuda/cutlass/util.cu
+++ b/dnn/src/cuda/cutlass/util.cu
@@ -533,7 +533,10 @@
         {LayoutTypeID::kTensorC16RSK16, "c16rsk16"},
         {LayoutTypeID::kTensorC32RSK32, "c32rsk32"},
         {LayoutTypeID::kTensorC64RSK64, "c64rsk64"},
-        {LayoutTypeID::kTensorK4RSC4, "k4rsC4"},
+        {LayoutTypeID::kTensorK4RSC4, "k4rsc4"},
+        {LayoutTypeID::kTensorCK4RS4, "ck4rs4"},
"ck4rs4"}, + {LayoutTypeID::kTensorCK8RS8, "ck8rs8"}, + {LayoutTypeID::kTensorCK16RS16, "ck16rs16"}, {LayoutTypeID::kUnknown, "*"}, {LayoutTypeID::kInvalid, nullptr}}; @@ -1499,6 +1502,8 @@ static struct { ThreadblockSwizzleID::kConvolutionFpropTrans}, {"convolution_dgrad_ncxhwx", "ConvolutionDgradNCxHWxThreadblockSwizzle", ThreadblockSwizzleID::kConvolutionDgradNCxHWx}, + {"convolution_dgrad_ncxhwx", "ConvolutionDgradTransThreadblockSwizzle", + ThreadblockSwizzleID::kConvolutionDgradTrans}, }; /// Converts a ThreadblockSwizzleID enumerant to a string diff --git a/dnn/src/cuda/transpose_utils.cuh b/dnn/src/cuda/transpose_utils.cuh new file mode 100644 index 00000000..a0a286f4 --- /dev/null +++ b/dnn/src/cuda/transpose_utils.cuh @@ -0,0 +1,69 @@ +/** + * \file dnn/src/cuda/memory_utils.cuh + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#if MEGDNN_CC_CUDA +#pragma once +#include "src/cuda/utils.cuh" + +namespace megdnn { +namespace cuda { + +MEGDNN_DEVICE __forceinline__ void transpose_int8_4x4_impl( + const int src0, const int src1, const int src2, const int src3, + int& dst0, int& dst1, int& dst2, int& dst3) { + int dst01_lo = __byte_perm(src0, src1, 0x5140); + int dst01_hi = __byte_perm(src0, src1, 0x7362); + int dst23_lo = __byte_perm(src2, src3, 0x5140); + int dst23_hi = __byte_perm(src2, src3, 0x7362); + dst0 = __byte_perm(dst01_lo, dst23_lo, 0x5410); + dst1 = __byte_perm(dst01_lo, dst23_lo, 0x7632); + dst2 = __byte_perm(dst01_hi, dst23_hi, 0x5410); + dst3 = __byte_perm(dst01_hi, dst23_hi, 0x7632); +} + +template +MEGDNN_DEVICE __forceinline__ void transpose_int8_interleavedx4( + const int src[interleaved], vec_type (&dst)[4]); + +template <> +MEGDNN_DEVICE __forceinline__ void transpose_int8_interleavedx4<4, int>( + const int src[4], int (&dst)[4]) { + transpose_int8_4x4_impl(src[0], src[1], src[2], src[3], dst[0], dst[1], + dst[2], dst[3]); +} + +template <> +MEGDNN_DEVICE __forceinline__ void transpose_int8_interleavedx4<8, int2>( + const int src[8], int2 (&dst)[4]) { + transpose_int8_4x4_impl(src[0], src[1], src[2], src[3], dst[0].x, dst[1].x, + dst[2].x, dst[3].x); + transpose_int8_4x4_impl(src[4], src[5], src[6], src[7], dst[0].y, dst[1].y, + dst[2].y, dst[3].y); +} + +template <> +MEGDNN_DEVICE __forceinline__ void transpose_int8_interleavedx4<16, int4>( + const int src[16], int4 (&dst)[4]) { + transpose_int8_4x4_impl(src[0], src[1], src[2], src[3], dst[0].x, dst[1].x, + dst[2].x, dst[3].x); + transpose_int8_4x4_impl(src[4], src[5], src[6], src[7], dst[0].y, dst[1].y, + dst[2].y, dst[3].y); + transpose_int8_4x4_impl(src[8], src[9], src[10], src[11], dst[0].z, + dst[1].z, dst[2].z, dst[3].z); + transpose_int8_4x4_impl(src[12], src[13], src[14], src[15], dst[0].w, + dst[1].w, dst[2].w, dst[3].w); +} + +} // namespace cuda +} // namespace megdnn +#endif + +// vim: ft=cpp syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/dnn/test/common/convolution.cpp b/dnn/test/common/convolution.cpp index fe126888..57a711c5 100644 --- a/dnn/test/common/convolution.cpp +++ b/dnn/test/common/convolution.cpp @@ -469,7 +469,6 @@ std::vector convolution::get_args_int8_nchw4_conv_bwd_data() { return args; } - std::vector 
+MEGDNN_DEVICE __forceinline__ void transpose_int8_4x4_impl(
+        const int src0, const int src1, const int src2, const int src3,
+        int& dst0, int& dst1, int& dst2, int& dst3) {
+    int dst01_lo = __byte_perm(src0, src1, 0x5140);
+    int dst01_hi = __byte_perm(src0, src1, 0x7362);
+    int dst23_lo = __byte_perm(src2, src3, 0x5140);
+    int dst23_hi = __byte_perm(src2, src3, 0x7362);
+    dst0 = __byte_perm(dst01_lo, dst23_lo, 0x5410);
+    dst1 = __byte_perm(dst01_lo, dst23_lo, 0x7632);
+    dst2 = __byte_perm(dst01_hi, dst23_hi, 0x5410);
+    dst3 = __byte_perm(dst01_hi, dst23_hi, 0x7632);
+}
+
+template <uint32_t interleaved, typename vec_type>
+MEGDNN_DEVICE __forceinline__ void transpose_int8_interleavedx4(
+        const int src[interleaved], vec_type (&dst)[4]);
+
+template <>
+MEGDNN_DEVICE __forceinline__ void transpose_int8_interleavedx4<4, int>(
+        const int src[4], int (&dst)[4]) {
+    transpose_int8_4x4_impl(src[0], src[1], src[2], src[3], dst[0], dst[1],
+                            dst[2], dst[3]);
+}
+
+template <>
+MEGDNN_DEVICE __forceinline__ void transpose_int8_interleavedx4<8, int2>(
+        const int src[8], int2 (&dst)[4]) {
+    transpose_int8_4x4_impl(src[0], src[1], src[2], src[3], dst[0].x, dst[1].x,
+                            dst[2].x, dst[3].x);
+    transpose_int8_4x4_impl(src[4], src[5], src[6], src[7], dst[0].y, dst[1].y,
+                            dst[2].y, dst[3].y);
+}
+
+template <>
+MEGDNN_DEVICE __forceinline__ void transpose_int8_interleavedx4<16, int4>(
+        const int src[16], int4 (&dst)[4]) {
+    transpose_int8_4x4_impl(src[0], src[1], src[2], src[3], dst[0].x, dst[1].x,
+                            dst[2].x, dst[3].x);
+    transpose_int8_4x4_impl(src[4], src[5], src[6], src[7], dst[0].y, dst[1].y,
+                            dst[2].y, dst[3].y);
+    transpose_int8_4x4_impl(src[8], src[9], src[10], src[11], dst[0].z,
+                            dst[1].z, dst[2].z, dst[3].z);
+    transpose_int8_4x4_impl(src[12], src[13], src[14], src[15], dst[0].w,
+                            dst[1].w, dst[2].w, dst[3].w);
+}
+
+}  // namespace cuda
+}  // namespace megdnn
+#endif
+
+// vim: ft=cpp syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
diff --git a/dnn/test/common/convolution.cpp b/dnn/test/common/convolution.cpp
index fe126888..57a711c5 100644
--- a/dnn/test/common/convolution.cpp
+++ b/dnn/test/common/convolution.cpp
@@ -469,7 +469,6 @@ std::vector<TestArg> convolution::get_args_int8_nchw4_conv_bwd_data() {
     return args;
 }
 
-
 std::vector<TestArg> convolution::get_args_int8_nchw_conv_bwd_data() {
     std::vector<TestArg> args;
     param::Convolution cur_param;
@@ -511,6 +510,46 @@ std::vector<TestArg> convolution::get_args_int8_nchw_conv_bwd_data() {
     return args;
 }
 
+std::vector<TestArg> convolution::get_args_int8_nhwc_conv_bwd_data() {
+    std::vector<TestArg> args;
+    param::Convolution cur_param;
+
+    // clang-format off
+    for (auto mode : {param::Convolution::Mode::CROSS_CORRELATION}) {
+    for (size_t b : {64, 16}) {
+    for (size_t ic : {16, 32}) {
+    for (size_t oc : {16, 32}) {
+    for (size_t h : {8}) {
+    for (size_t w : {8, 11}) {
+    for (size_t kernel_size : {3, 4, 5, 7}) {
+    for (int p : {0, static_cast<int>(kernel_size / 2)}) {
+    for (size_t s : {2}) {
+        if (kernel_size >= 7) {
+            b = std::min(b, 32_z);
+        }
+        size_t f = kernel_size;
+        cur_param.mode = mode;
+
+        cur_param.format = param::Convolution::Format::NHWC;
+        cur_param.sparse = param::Convolution::Sparse::DENSE;
+        cur_param.pad_h = cur_param.pad_w = p;
+        cur_param.stride_h = cur_param.stride_w = s;
+
+        //! bias channel
+        args.emplace_back(cur_param, TensorShape{b, h, w, ic},
+                          TensorShape{oc, f, f, ic});
+    } } } } } } } } }
+    // clang-format on
+
+    cur_param.pad_h = cur_param.pad_w = 1;
+    cur_param.stride_h = cur_param.stride_w = 1;
+
+    args.emplace_back(cur_param, TensorShape{16, 8, 11, 16},
+                      TensorShape{16, 3, 3, 16});
+
+    return args;
+}
+
 void convolution::test_conv_config_combinations(
         int k_size, Handle* handle, bool test_int8, bool test_backward,
         bool is_cuda, ConvEPSGetter eps_getter, bool use_io16xc32) {
diff --git a/dnn/test/common/convolution.h b/dnn/test/common/convolution.h
index 109059f8..f4e10b38 100644
--- a/dnn/test/common/convolution.h
+++ b/dnn/test/common/convolution.h
@@ -50,6 +50,7 @@ std::vector<TestArg> get_dilated_args();
 std::vector<TestArg> get_chanwise_args();
 std::vector<TestArg> get_args_int8_nchw4_conv_bwd_data();
 std::vector<TestArg> get_args_int8_nchw_conv_bwd_data();
+std::vector<TestArg> get_args_int8_nhwc_conv_bwd_data();
 
 //! \param stage 0 for fwd, 1 for bwd data, 2 for bwd filter
 using ConvEPSGetter =
diff --git a/dnn/test/cuda/convolution.cpp b/dnn/test/cuda/convolution.cpp
index b4172536..c88d1ac6 100644
--- a/dnn/test/cuda/convolution.cpp
+++ b/dnn/test/cuda/convolution.cpp
@@ -386,6 +386,69 @@ TEST_F(CUDA, CONVOLUTION_BACKWARD_DATA_INT8_NCHW_DP4A) {
     }
 }
 
+#if CUDA_VERSION >= 10020
+TEST_F(CUDA, CONVOLUTION_BACKWARD_DATA_INT8_NHWC_IMMA) {
+    if (!cuda::is_compute_capability_required(7, 5)) {
+        printf("Skip CUDA.CONVOLUTION_BACKWARD_DATA_INT8_NHWC_IMMA test as "
+               "current device doesn't support\n");
+        return;
+    }
+
+    using namespace convolution;
+    std::vector<TestArg> args = get_args_int8_nhwc_conv_bwd_data();
+
+    struct AlgoParam {
+        int threadblock_m;
+        int threadblock_n;
+        int threadblock_k;
+        int warp_m;
+        int warp_n;
+        int warp_k;
+        int stage;
+        int access_size;
+        std::string to_string() {
+            return ssprintf("_%dX%dX%d_%dX%dX%d_%dstage_%d", threadblock_m,
+                            threadblock_n, threadblock_k, warp_m, warp_n,
+                            warp_k, stage, access_size);
+        }
+    };
+
+    std::vector<AlgoParam> all_params;
+
+    all_params.emplace_back(AlgoParam{64, 16, 32, 64, 16, 32, 2, 4});
+    all_params.emplace_back(AlgoParam{64, 16, 32, 64, 16, 32, 2, 8});
+    all_params.emplace_back(AlgoParam{64, 16, 32, 64, 16, 32, 2, 16});
+    all_params.emplace_back(AlgoParam{128, 32, 32, 64, 32, 32, 1, 4});
+    all_params.emplace_back(AlgoParam{128, 32, 32, 64, 32, 32, 1, 8});
+    all_params.emplace_back(AlgoParam{128, 32, 32, 64, 32, 32, 1, 16});
+
+    for (auto algo_param : all_params) {
+        Checker<ConvolutionBackwardData> checker(handle_cuda());
+        std::string algo_name(ssprintf("INT8_NHWC_IMMA_IMPLICIT_GEMM%s",
+                                       algo_param.to_string().c_str()));
+        checker.set_before_exec_callback(
+                AlgoChecker<ConvolutionBackwardData>(algo_name.c_str()));
+
+        checker.set_epsilon(1 + 1e-3).set_max_avg_error(1e-1);
+
+        for (auto&& arg : args) {
+            UniformIntRNG rng(-3, 3);
+            auto src = TensorLayout(arg.src, dtype::QuantizedS8{1.2f});
+            auto filter = TensorLayout(arg.filter, dtype::QuantizedS8{1.3f});
+            TensorLayout dst;
+            dst.dtype = dtype::QuantizedS8{1.2f};
+            {
+                auto opr = handle_cuda()->create_operator<ConvolutionForward>();
+                opr->param() = arg.param;
+                opr->deduce_layout(src, filter, dst);
+            }
+            // backward data takes (filter, diff) and computes grad, so the
+            // deduced forward output `dst` is fed as diff and `src` gives the
+            // expected grad layout
+            checker.set_rng(0, &rng).set_rng(1, &rng).set_param(arg.param).exec(
+                    TensorLayoutArray{filter, dst, src});
+        }
+    }
+}
+#endif
+
 TEST_F(CUDA, CONVOLUTION_BACKWARD_DATA_FAILED_CUDNN7_5) {
     // BRAIN-481 failed on architectures 7.0, remove the following if
     // statement, when cudnn fixed the problem.