From 9b4b910dc11b84ed619d90b77689f025c0df8876 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 12 Jul 2021 18:17:50 +0800 Subject: [PATCH] feat(dnn/cuda): integrate cutlass operation table and replace all cutlass wrappers GitOrigin-RevId: 2a70335441e8a844dcf3c2d00bbf6db381ad9623 --- dnn/scripts/cutlass_generator/conv2d_operation.py | 73 +- dnn/scripts/cutlass_generator/gemm_operation.py | 121 +- dnn/scripts/cutlass_generator/gen_list.py | 2 + dnn/scripts/cutlass_generator/generator.py | 28 +- dnn/scripts/cutlass_generator/list.bzl | 25 + dnn/scripts/cutlass_generator/manifest.py | 39 + dnn/src/cuda/conv_bias/algo.cpp | 94 +- dnn/src/cuda/conv_bias/algo.h | 186 +-- .../cuda/conv_bias/cutlass_convolution_base.cpp | 253 ++++ .../cuda/conv_bias/cutlass_convolution_wrapper.cuh | 129 -- .../conv_bias/cutlass_convolution_wrapper_int4.cu | 595 -------- .../conv_bias/cutlass_convolution_wrapper_int8.cu | 804 ---------- .../implicit_gemm_conv_bias_cutlass_wrapper.cuinl | 65 - .../implicit_gemm_int4_int4_nchw64_imma.cpp | 26 +- .../implicit_gemm_int4_int4_nhwc_imma.cpp | 39 +- .../implicit_gemm_int4_nchw64_imma_base.cpp | 41 +- .../implicit_gemm_int4_nhwc_imma_base.cpp | 44 +- .../conv_bias/implicit_gemm_int8_nchw32_imma.cpp | 112 +- .../conv_bias/implicit_gemm_int8_nchw4_dp4a.cpp | 139 +- .../implicit_gemm_uint4_int4_nchw64_imma.cpp | 36 +- .../implicit_gemm_uint4_int4_nhwc_imma.cpp | 48 +- dnn/src/cuda/conv_bias/opr_impl.h | 1 + .../backward_data/cutlass_deconvolution_wrapper.cu | 100 -- .../cutlass_deconvolution_wrapper.cuh | 44 - .../implicit_gemm_deconv_cutlass_wrapper.cuinl | 62 - .../implicit_gemm_int8_nchw4_dp4a.cpp | 84 +- .../backward_data/implicit_gemm_int8_nchw_dp4a.cpp | 76 +- dnn/src/cuda/cutlass/arch_mappings.h | 107 ++ dnn/src/cuda/cutlass/convolution_operation.h | 307 ++++ dnn/src/cuda/cutlass/gemm_operation.h | 202 +++ dnn/src/cuda/cutlass/initialize_all.cu | 76 + dnn/src/cuda/cutlass/library.h | 541 +++++++ dnn/src/cuda/cutlass/library_internal.h | 580 +++++++ dnn/src/cuda/cutlass/manifest.cpp | 96 ++ dnn/src/cuda/cutlass/manifest.h | 108 ++ dnn/src/cuda/cutlass/operation_table.cpp | 179 +++ dnn/src/cuda/cutlass/operation_table.h | 334 ++++ dnn/src/cuda/cutlass/singleton.cu | 72 + dnn/src/cuda/cutlass/singleton.h | 70 + dnn/src/cuda/cutlass/util.cu | 1600 ++++++++++++++++++++ dnn/src/cuda/cutlass/util.h | 218 +++ dnn/src/cuda/matrix_mul/cutlass_float32_simt.cpp | 68 +- .../matrix_mul/cutlass_float32_simt_split_k.cpp | 68 +- .../cuda/matrix_mul/cutlass_matrix_mul_wrapper.cu | 157 -- .../cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuh | 16 - .../matrix_mul/cutlass_matrix_mul_wrapper.cuinl | 57 - 46 files changed, 5456 insertions(+), 2666 deletions(-) create mode 100644 dnn/src/cuda/conv_bias/cutlass_convolution_base.cpp delete mode 100644 dnn/src/cuda/conv_bias/cutlass_convolution_wrapper.cuh delete mode 100644 dnn/src/cuda/conv_bias/cutlass_convolution_wrapper_int4.cu delete mode 100644 dnn/src/cuda/conv_bias/cutlass_convolution_wrapper_int8.cu delete mode 100644 dnn/src/cuda/conv_bias/implicit_gemm_conv_bias_cutlass_wrapper.cuinl delete mode 100644 dnn/src/cuda/convolution/backward_data/cutlass_deconvolution_wrapper.cu delete mode 100644 dnn/src/cuda/convolution/backward_data/cutlass_deconvolution_wrapper.cuh delete mode 100644 dnn/src/cuda/convolution/backward_data/implicit_gemm_deconv_cutlass_wrapper.cuinl create mode 100644 dnn/src/cuda/cutlass/arch_mappings.h create mode 100644 dnn/src/cuda/cutlass/convolution_operation.h create mode 100644 
dnn/src/cuda/cutlass/gemm_operation.h create mode 100644 dnn/src/cuda/cutlass/initialize_all.cu create mode 100644 dnn/src/cuda/cutlass/library.h create mode 100644 dnn/src/cuda/cutlass/library_internal.h create mode 100644 dnn/src/cuda/cutlass/manifest.cpp create mode 100644 dnn/src/cuda/cutlass/manifest.h create mode 100644 dnn/src/cuda/cutlass/operation_table.cpp create mode 100644 dnn/src/cuda/cutlass/operation_table.h create mode 100644 dnn/src/cuda/cutlass/singleton.cu create mode 100644 dnn/src/cuda/cutlass/singleton.h create mode 100644 dnn/src/cuda/cutlass/util.cu create mode 100644 dnn/src/cuda/cutlass/util.h delete mode 100644 dnn/src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cu delete mode 100644 dnn/src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuinl diff --git a/dnn/scripts/cutlass_generator/conv2d_operation.py b/dnn/scripts/cutlass_generator/conv2d_operation.py index 324c2a52..1e4e43ed 100644 --- a/dnn/scripts/cutlass_generator/conv2d_operation.py +++ b/dnn/scripts/cutlass_generator/conv2d_operation.py @@ -163,7 +163,7 @@ using Convolution = ${element_bias}, ${layout_bias}, ${element_accumulator}, - ${conv_type}, + ${conv_type}, ${opcode_class}, ${arch}, cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, @@ -246,6 +246,7 @@ using Deconvolution = ${element_bias}, ${layout_bias}, ${element_accumulator}, + ${conv_type}, ${opcode_class}, ${arch}, cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, @@ -276,6 +277,7 @@ using Deconvolution = values = { 'operation_name': operation.procedural_name(), + 'conv_type': ConvTypeTag[operation.conv_type], 'element_src': DataTypeTag[operation.src.element], 'layout_src': LayoutTag[operation.src.layout], 'element_flt': DataTypeTag[operation.flt.element], @@ -530,44 +532,17 @@ void initialize_${configuration_name}(Manifest &manifest) { ################################################################################################### class EmitConvSingleKernelWrapper(): - def __init__(self, kernel_path, operation, wrapper_path): + def __init__(self, kernel_path, operation): self.kernel_path = kernel_path - self.wrapper_path = wrapper_path self.operation = operation - self.conv_wrappers = { \ - ConvKind.Fprop: """ -template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper( - const typename Convolution::ElementSrc* d_src, - const typename Convolution::ElementFilter* d_filter, - const typename Convolution::ElementBias* d_bias, - const typename Convolution::ElementDst* d_z, - typename Convolution::ElementDst* d_dst, - int* workspace, - typename Convolution::ConvolutionParameter const& conv_param, - typename Convolution::EpilogueOutputOp::Params const& epilogue, - cudaStream_t stream, - typename Convolution::ExtraParam extra_param); -""", \ - ConvKind.Dgrad: """ -template void megdnn::cuda::cutlass_wrapper::cutlass_deconvolution_wrapper( - const typename Deconvolution::ElementSrc* d_src, - const typename Deconvolution::ElementFilter* d_filter, - const typename Deconvolution::ElementBias* d_bias, - const typename Deconvolution::ElementDst* d_z, - typename Deconvolution::ElementDst* d_dst, - int* workspace, - typename Deconvolution::ConvolutionParameter const& conv_param, - typename Deconvolution::EpilogueOutputOp::Params const& epilogue, - cudaStream_t stream); -""", \ - } - if self.operation.conv_kind == ConvKind.Fprop: self.instance_emitter = EmitConv2dInstance() + self.convolution_name = "Convolution" else: assert self.operation.conv_kind 
== ConvKind.Dgrad self.instance_emitter = EmitDeconvInstance() + self.convolution_name = "Deconvolution" self.header_template = """ #if !MEGDNN_TEGRA_X1 @@ -575,13 +550,30 @@ template void megdnn::cuda::cutlass_wrapper::cutlass_deconvolution_wrapper( + "${operation_name}" + )); +} + +} // namespace library +} // namespace cutlass """ self.epilogue_template = """ @@ -593,9 +585,7 @@ ${wrapper_instance} def __enter__(self): self.kernel_path = os.path.join(self.kernel_path, "%s.cu" % self.operation.procedural_name()) self.kernel_file = LazyFile(self.kernel_path) - self.kernel_file.write(SubstituteTemplate(self.header_template, { - 'wrapper_path': self.wrapper_path, - })) + self.kernel_file.write(self.header_template) return self # @@ -604,11 +594,12 @@ ${wrapper_instance} 'operation_instance': self.instance_emitter.emit(self.operation), })) - # emit wrapper - wrapper = SubstituteTemplate(self.wrapper_template, { - 'wrapper_instance': self.conv_wrappers[self.operation.conv_kind], + # emit manifest helper + manifest = SubstituteTemplate(self.manifest_template, { + 'operation_name': self.operation.procedural_name(), + 'convolution_name': self.convolution_name }) - self.kernel_file.write(wrapper) + self.kernel_file.write(manifest) # def __exit__(self, exception_type, exception_value, traceback): diff --git a/dnn/scripts/cutlass_generator/gemm_operation.py b/dnn/scripts/cutlass_generator/gemm_operation.py index 3cd28715..f169235c 100644 --- a/dnn/scripts/cutlass_generator/gemm_operation.py +++ b/dnn/scripts/cutlass_generator/gemm_operation.py @@ -940,8 +940,8 @@ void initialize_${configuration_name}(Manifest &manifest) { /////////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace library -} // namespace cutlass +} // namespace library +} // namespace cutlass /////////////////////////////////////////////////////////////////////////////////////////////////// @@ -995,48 +995,101 @@ void initialize_${configuration_name}(Manifest &manifest) { ################################################################################################### class EmitGemmSingleKernelWrapper: - def __init__(self, kernel_path, gemm_operation, wrapper_path): + def __init__(self, kernel_path, gemm_operation): self.kernel_path = kernel_path - self.wrapper_path = wrapper_path self.operation = gemm_operation - gemm_wrapper = """ -template void megdnn::cuda::cutlass_wrapper::cutlass_matrix_mul_wrapper( - const typename Operation_${operation_name}::ElementA* d_A, size_t lda, - const typename Operation_${operation_name}::ElementB* d_B, size_t ldb, - typename Operation_${operation_name}::ElementC* d_C, size_t ldc, - int* workspace, - cutlass::gemm::GemmCoord const& problem_size, - typename Operation_${operation_name}::EpilogueOutputOp::Params const& epilogue, - cudaStream_t stream, int split_k_slices); + instance_emitters = { + GemmKind.Gemm: EmitGemmInstance(), + GemmKind.SplitKParallel: EmitGemmSplitKParallelInstance(), + } + self.instance_emitter = instance_emitters[self.operation.gemm_kind] + + self.header_template = """ +#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2) +// ignore warning of cutlass +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-parameter" +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#pragma GCC diagnostic ignored "-Wuninitialized" +#pragma GCC diagnostic ignored "-Wmaybe-uninitialized" + +#include "cutlass/gemm/device/gemm.h" +#include "cutlass/gemm/device/gemm_splitk_parallel.h" + 
+#include "src/cuda/cutlass/manifest.h" +#include "src/cuda/cutlass/gemm_operation.h" """ + self.instance_template = """ +${operation_instance} +""" + + self.manifest_template = """ +namespace cutlass { +namespace library { + +void initialize_${operation_name}(Manifest &manifest) { + manifest.append(new GemmOperation< + Operation_${operation_name} + >("${operation_name}")); +} + +} // namespace library +} // namespace cutlass +""" + + self.epilogue_template = """ +#pragma GCC diagnostic pop +#endif +""" + # + def __enter__(self): + self.kernel_path = os.path.join(self.kernel_path, "%s.cu" % self.operation.procedural_name()) + self.kernel_file = LazyFile(self.kernel_path) + self.kernel_file.write(self.header_template) + return self + + # + def emit(self): + self.kernel_file.write(SubstituteTemplate(self.instance_template, { + 'operation_instance': self.instance_emitter.emit(self.operation), + })) - gemv_wrapper = """ + # emit manifest helper + manifest = SubstituteTemplate(self.manifest_template, { + 'operation_name': self.operation.procedural_name(), + }) + self.kernel_file.write(manifest) + + # + def __exit__(self, exception_type, exception_value, traceback): + self.kernel_file.write(self.epilogue_template) + self.kernel_file.close() + + +################################################################################################### +################################################################################################### + +class EmitGemvSingleKernelWrapper: + def __init__(self, kernel_path, gemm_operation, wrapper_path): + self.kernel_path = kernel_path + self.wrapper_path = wrapper_path + self.operation = gemm_operation + + self.wrapper_template = """ template void megdnn::cuda::cutlass_wrapper:: cutlass_vector_matrix_mul_batched_strided_wrapper( BatchedGemmCoord const& problem_size, - const typename Operation_${operation_name}::ElementA* d_A, size_t lda, size_t batch_stride_a, - const typename Operation_${operation_name}::ElementB* d_B, size_t ldb, size_t batch_stride_b, + const typename Operation_${operation_name}::ElementA* d_A, size_t lda, size_t batch_stride_a, + const typename Operation_${operation_name}::ElementB* d_B, size_t ldb, size_t batch_stride_b, typename Operation_${operation_name}::ElementCD* d_C, size_t ldc, size_t batch_stride_c, cudaStream_t stream); """ - if self.operation.gemm_kind == GemmKind.SplitKParallel or \ - self.operation.gemm_kind == GemmKind.Gemm: - self.wrapper_template = gemm_wrapper - else: - assert self.operation.gemm_kind == GemmKind.GemvBatchedStrided - self.wrapper_template = gemv_wrapper - - instance_emitters = { - GemmKind.Gemm: EmitGemmInstance(), - GemmKind.SplitKParallel: EmitGemmSplitKParallelInstance(), - GemmKind.GemvBatchedStrided: EmitGemvBatchedStridedInstance(), - } - self.instance_emitter = instance_emitters[self.operation.gemm_kind] + self.instance_emitter = EmitGemvBatchedStridedInstance() self.header_template = """ -#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2) +#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2) // ignore warning of cutlass #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wunused-parameter" @@ -1055,10 +1108,10 @@ ${operation_instance} """ # def __enter__(self): - self.kernel_path = os.path.join(self.kernel_path, "%s.cu" % self.operation.procedural_name()) + self.kernel_path = os.path.join(self.kernel_path, "%s.cu" % self.operation.procedural_name()) self.kernel_file = LazyFile(self.kernel_path) 
self.kernel_file.write(SubstituteTemplate(self.header_template, { - 'wrapper_path': self.wrapper_path, + 'wrapper_path': self.wrapper_path, })) return self @@ -1070,7 +1123,7 @@ ${operation_instance} # emit wrapper wrapper = SubstituteTemplate(self.wrapper_template, { - 'operation_name': self.operation.procedural_name(), + 'operation_name': self.operation.procedural_name(), }) self.kernel_file.write(wrapper) @@ -1079,7 +1132,5 @@ ${operation_instance} self.kernel_file.write(self.epilogue_template) self.kernel_file.close() - ################################################################################################### ################################################################################################### - diff --git a/dnn/scripts/cutlass_generator/gen_list.py b/dnn/scripts/cutlass_generator/gen_list.py index 08ca0e88..7c61f73d 100644 --- a/dnn/scripts/cutlass_generator/gen_list.py +++ b/dnn/scripts/cutlass_generator/gen_list.py @@ -23,6 +23,8 @@ def write_op_list(f, gen_op, gen_type): operations = GenerateDeconvOperations(GenArg(gen_op, gen_type)) for op in operations: f.write(' "%s.cu",\n' % op.procedural_name()) + if gen_op != "gemv": + f.write(' "all_%s_%s_operations.cu",\n' % (gen_op, gen_type)) if __name__ == "__main__": diff --git a/dnn/scripts/cutlass_generator/generator.py b/dnn/scripts/cutlass_generator/generator.py index ed6149de..73018de6 100644 --- a/dnn/scripts/cutlass_generator/generator.py +++ b/dnn/scripts/cutlass_generator/generator.py @@ -292,7 +292,7 @@ def GenerateConv2d_TensorOp_8832(args): ] operations += GenerateConv2d(ConvKind.Fprop, tile_descriptions, layout[0], layout[1], dst_layout, dst_type, min_cc, 128, 128, 64, - True, ImplicitGemmMode.GemmTN, True) + False, ImplicitGemmMode.GemmTN, True) layouts_nhwc = [ (LayoutType.TensorNHWC, LayoutType.TensorNC8HW8, 32), @@ -633,16 +633,10 @@ if __name__ == "__main__": parser.add_argument("--type", type=str, choices=['simt', 'tensorop8816', 'tensorop8832'], default='simt', help="kernel type of CUTLASS kernel generator") - operation2wrapper_path = { - "gemm": "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuinl", \ - "gemv": "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper_batched_gemv_strided.cuinl", \ - "conv2d": "src/cuda/conv_bias/implicit_gemm_conv_bias_cutlass_wrapper.cuinl", \ - "deconv": "src/cuda/convolution/backward_data/implicit_gemm_deconv_cutlass_wrapper.cuinl", \ - } + gemv_wrapper_path = "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper_batched_gemv_strided.cuinl" args = parser.parse_args() - wrapper_path = operation2wrapper_path[args.operations] if args.operations == "gemm": operations = GenerateGemmOperations(args) elif args.operations == "gemv": @@ -652,16 +646,22 @@ if __name__ == "__main__": elif args.operations == "deconv": operations = GenerateDeconvOperations(args) - if args.operations == "conv2d" or args.operations == "deconv": for operation in operations: - with EmitConvSingleKernelWrapper(args.output, operation, wrapper_path) as emitter: + with EmitConvSingleKernelWrapper(args.output, operation) as emitter: emitter.emit() - elif args.operations == "gemm" or args.operations == "gemv": + elif args.operations == "gemm": for operation in operations: - with EmitGemmSingleKernelWrapper(args.output, operation, wrapper_path) as emitter: + with EmitGemmSingleKernelWrapper(args.output, operation) as emitter: emitter.emit() - + elif args.operations == "gemv": + for operation in operations: + with EmitGemvSingleKernelWrapper(args.output, operation, gemv_wrapper_path) as emitter: + emitter.emit() 
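+    # besides the per-kernel .cu files, non-gemv runs also emit a manifest
+    # initializer (all_<operations>_<type>_operations.cu) that registers every
+    # generated kernel; gemv keeps the legacy wrapper scheme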
+ + if args.operations != "gemv": + GenerateManifest(args, operations, args.output) + # ################################################################################################### - \ No newline at end of file + diff --git a/dnn/scripts/cutlass_generator/list.bzl b/dnn/scripts/cutlass_generator/list.bzl index d84c0a3b..0d4b649a 100644 --- a/dnn/scripts/cutlass_generator/list.bzl +++ b/dnn/scripts/cutlass_generator/list.bzl @@ -137,6 +137,7 @@ cutlass_gen_list = [ "cutlass_simt_sgemm_split_k_parallel_256x32_8x2_tt_align1.cu", "cutlass_simt_sgemm_256x64_8x2_tt_align1.cu", "cutlass_simt_sgemm_split_k_parallel_256x64_8x2_tt_align1.cu", + "all_gemm_simt_operations.cu", "cutlass_simt_sgemv_batched_strided_1x128_32_tt_align4x4.cu", "cutlass_simt_sgemv_batched_strided_1x128_16_tt_align4x2.cu", "cutlass_simt_sgemv_batched_strided_1x128_8_tt_align4x1.cu", @@ -169,6 +170,7 @@ cutlass_gen_list = [ "cutlass_simt_s8_idgrad_id_s8_16x128x16_16x64x16_2_nc4hw4_k4rsc4.cu", "cutlass_simt_s8_idgrad_id_s8_16x128x16_16x128x16_1_nc4hw4_k4rsc4.cu", "cutlass_simt_s8_idgrad_id_s8_16x64x8_16x64x8_2_nc4hw4_k4rsc4.cu", + "all_deconv_simt_operations.cu", "cutlass_simt_s8_ifprop_id_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4.cu", "cutlass_simt_s8_ifprop_1x1_id_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4.cu", "cutlass_simt_s8_ifprop_relu_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4.cu", @@ -373,6 +375,7 @@ cutlass_gen_list = [ "cutlass_simt_f32_ifprop_1x1_relu_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4_nchw.cu", "cutlass_simt_f32_ifprop_hswish_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4_nchw.cu", "cutlass_simt_f32_ifprop_1x1_hswish_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4_nchw.cu", + "all_conv2d_simt_operations.cu", "cutlass_tensorop_s8_i8816fprop_roc_id_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32.cu", "cutlass_tensorop_s8_i8816fprop_1x1_roc_id_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32.cu", "cutlass_tensorop_s8_i8816fprop_roc_relu_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32.cu", @@ -481,26 +484,47 @@ cutlass_gen_list = [ "cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_32x128x32_32x64x32_1_nc32hw32_c32rsk32_nc4hw4.cu", "cutlass_tensorop_s8_i8816fprop_hswish_s8_32x128x32_32x64x32_1_nc32hw32_c32rsk32_nc4hw4.cu", "cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_32x128x32_32x64x32_1_nc32hw32_c32rsk32_nc4hw4.cu", + "all_conv2d_tensorop8816_operations.cu", "cutlass_tensorop_s4_i8832fprop_roc_id_s4_128x256x128_64x64x128_2_nc64hw64_c64rsk64.cu", + "cutlass_tensorop_s4_i8832fprop_1x1_roc_id_s4_128x256x128_64x64x128_2_nc64hw64_c64rsk64.cu", "cutlass_tensorop_s4_i8832fprop_roc_relu_s4_128x256x128_64x64x128_2_nc64hw64_c64rsk64.cu", + "cutlass_tensorop_s4_i8832fprop_1x1_roc_relu_s4_128x256x128_64x64x128_2_nc64hw64_c64rsk64.cu", "cutlass_tensorop_s4_i8832fprop_roc_hswish_s4_128x256x128_64x64x128_2_nc64hw64_c64rsk64.cu", + "cutlass_tensorop_s4_i8832fprop_1x1_roc_hswish_s4_128x256x128_64x64x128_2_nc64hw64_c64rsk64.cu", "cutlass_tensorop_s4_i8832fprop_roc_id_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", + "cutlass_tensorop_s4_i8832fprop_1x1_roc_id_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", "cutlass_tensorop_s4_i8832fprop_roc_relu_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", + "cutlass_tensorop_s4_i8832fprop_1x1_roc_relu_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", "cutlass_tensorop_s4_i8832fprop_roc_hswish_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", + "cutlass_tensorop_s4_i8832fprop_1x1_roc_hswish_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", "cutlass_tensorop_s4_i8832fprop_roc_id_s4_128x64x128_64x64x128_2_nc64hw64_c64rsk64.cu", + 
"cutlass_tensorop_s4_i8832fprop_1x1_roc_id_s4_128x64x128_64x64x128_2_nc64hw64_c64rsk64.cu", "cutlass_tensorop_s4_i8832fprop_roc_relu_s4_128x64x128_64x64x128_2_nc64hw64_c64rsk64.cu", + "cutlass_tensorop_s4_i8832fprop_1x1_roc_relu_s4_128x64x128_64x64x128_2_nc64hw64_c64rsk64.cu", "cutlass_tensorop_s4_i8832fprop_roc_hswish_s4_128x64x128_64x64x128_2_nc64hw64_c64rsk64.cu", + "cutlass_tensorop_s4_i8832fprop_1x1_roc_hswish_s4_128x64x128_64x64x128_2_nc64hw64_c64rsk64.cu", "cutlass_tensorop_s4_i8832fprop_roc_id_s4_128x64x64_64x64x64_1_nc64hw64_c64rsk64.cu", + "cutlass_tensorop_s4_i8832fprop_1x1_roc_id_s4_128x64x64_64x64x64_1_nc64hw64_c64rsk64.cu", "cutlass_tensorop_s4_i8832fprop_roc_relu_s4_128x64x64_64x64x64_1_nc64hw64_c64rsk64.cu", + "cutlass_tensorop_s4_i8832fprop_1x1_roc_relu_s4_128x64x64_64x64x64_1_nc64hw64_c64rsk64.cu", "cutlass_tensorop_s4_i8832fprop_roc_hswish_s4_128x64x64_64x64x64_1_nc64hw64_c64rsk64.cu", + "cutlass_tensorop_s4_i8832fprop_1x1_roc_hswish_s4_128x64x64_64x64x64_1_nc64hw64_c64rsk64.cu", "cutlass_tensorop_u4_i8832fprop_roc_id_u4_s4_128x256x128_64x64x128_2_nc64hw64_c64rsk64.cu", + "cutlass_tensorop_u4_i8832fprop_1x1_roc_id_u4_s4_128x256x128_64x64x128_2_nc64hw64_c64rsk64.cu", "cutlass_tensorop_u4_i8832fprop_roc_relu_u4_s4_128x256x128_64x64x128_2_nc64hw64_c64rsk64.cu", + "cutlass_tensorop_u4_i8832fprop_1x1_roc_relu_u4_s4_128x256x128_64x64x128_2_nc64hw64_c64rsk64.cu", "cutlass_tensorop_u4_i8832fprop_roc_id_u4_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", + "cutlass_tensorop_u4_i8832fprop_1x1_roc_id_u4_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", "cutlass_tensorop_u4_i8832fprop_roc_relu_u4_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", + "cutlass_tensorop_u4_i8832fprop_1x1_roc_relu_u4_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", "cutlass_tensorop_u4_i8832fprop_roc_id_u4_s4_128x64x128_64x64x128_2_nc64hw64_c64rsk64.cu", + "cutlass_tensorop_u4_i8832fprop_1x1_roc_id_u4_s4_128x64x128_64x64x128_2_nc64hw64_c64rsk64.cu", "cutlass_tensorop_u4_i8832fprop_roc_relu_u4_s4_128x64x128_64x64x128_2_nc64hw64_c64rsk64.cu", + "cutlass_tensorop_u4_i8832fprop_1x1_roc_relu_u4_s4_128x64x128_64x64x128_2_nc64hw64_c64rsk64.cu", "cutlass_tensorop_u4_i8832fprop_roc_id_u4_s4_128x64x64_64x64x64_1_nc64hw64_c64rsk64.cu", + "cutlass_tensorop_u4_i8832fprop_1x1_roc_id_u4_s4_128x64x64_64x64x64_1_nc64hw64_c64rsk64.cu", "cutlass_tensorop_u4_i8832fprop_roc_relu_u4_s4_128x64x64_64x64x64_1_nc64hw64_c64rsk64.cu", + "cutlass_tensorop_u4_i8832fprop_1x1_roc_relu_u4_s4_128x64x64_64x64x64_1_nc64hw64_c64rsk64.cu", "cutlass_tensorop_s4_i8832fprop_id_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", "cutlass_tensorop_s4_i8832fprop_1x1_id_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", "cutlass_tensorop_s4_i8832fprop_relu_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", @@ -621,4 +645,5 @@ cutlass_gen_list = [ "cutlass_tensorop_u4_i8832fprop_1x1_roc_id_u4_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", "cutlass_tensorop_u4_i8832fprop_roc_relu_u4_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", "cutlass_tensorop_u4_i8832fprop_1x1_roc_relu_u4_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", + "all_conv2d_tensorop8832_operations.cu", ] \ No newline at end of file diff --git a/dnn/scripts/cutlass_generator/manifest.py b/dnn/scripts/cutlass_generator/manifest.py index 57333f39..8ca484bd 100644 --- a/dnn/scripts/cutlass_generator/manifest.py +++ b/dnn/scripts/cutlass_generator/manifest.py @@ -8,6 +8,7 @@ import enum import os.path import shutil +from lazy_file import LazyFile from library import * from gemm_operation import * from conv2d_operation import * @@ 
-349,3 +350,41 @@ void initialize_all(Manifest &manifest) { # ################################################################################################### + +def GenerateManifest(args, operations, output_dir): + manifest_path = os.path.join(output_dir, "all_%s_%s_operations.cu" % (args.operations, args.type)) + f = LazyFile(manifest_path) + f.write(""" +/* + Generated by generator.py - Do not edit. +*/ + +#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2) + +#include "cutlass/cutlass.h" +#include "src/cuda/cutlass/library.h" +#include "src/cuda/cutlass/manifest.h" + +namespace cutlass { +namespace library { + +""") + for op in operations: + f.write("void initialize_%s(Manifest &manifest);\n" % op.procedural_name()) + + f.write(""" +void initialize_all_%s_%s_operations(Manifest &manifest) { +""" % (args.operations, args.type)) + + for op in operations: + f.write(" initialize_%s(manifest);\n" % op.procedural_name()) + + f.write(""" +} + +} // namespace library +} // namespace cutlass + +#endif +""") + f.close() diff --git a/dnn/src/cuda/conv_bias/algo.cpp b/dnn/src/cuda/conv_bias/algo.cpp index ed0a2eb7..6c687dfe 100644 --- a/dnn/src/cuda/conv_bias/algo.cpp +++ b/dnn/src/cuda/conv_bias/algo.cpp @@ -217,68 +217,77 @@ void ConvBiasForwardImpl::AlgoPack::fill_imma_algos() { #if CUDA_VERSION >= 10020 { using AlgoParam = AlgoInt8NCHW32IMMAImplicitGemm::AlgoParam; - int8_nchw32_imma.emplace_back(AlgoParam{128, 256, 64, 64, 64, 64, 2}); - int8_nchw32_imma.emplace_back(AlgoParam{256, 128, 64, 64, 64, 64, 2}); - int8_nchw32_imma.emplace_back(AlgoParam{128, 128, 64, 64, 64, 64, 2}); - int8_nchw32_imma.emplace_back(AlgoParam{128, 64, 64, 64, 32, 64, 2}); - int8_nchw32_imma.emplace_back(AlgoParam{64, 128, 64, 32, 64, 64, 2}); - int8_nchw32_imma.emplace_back(AlgoParam{128, 64, 32, 64, 32, 32, 1}); - int8_nchw32_imma.emplace_back(AlgoParam{128, 32, 32, 64, 32, 32, 1}); - int8_nchw32_imma.emplace_back(AlgoParam{64, 128, 32, 32, 64, 32, 1}); - int8_nchw32_imma.emplace_back(AlgoParam{32, 128, 32, 32, 64, 32, 1}); + int8_nchw32_imma.emplace_back( + AlgoParam{128, 256, 64, 64, 64, 64, 8, 8, 16, 2}); + int8_nchw32_imma.emplace_back( + AlgoParam{256, 128, 64, 64, 64, 64, 8, 8, 16, 2}); + int8_nchw32_imma.emplace_back( + AlgoParam{128, 128, 64, 64, 64, 64, 8, 8, 16, 2}); + int8_nchw32_imma.emplace_back( + AlgoParam{128, 64, 64, 64, 32, 64, 8, 8, 16, 2}); + int8_nchw32_imma.emplace_back( + AlgoParam{64, 128, 64, 32, 64, 64, 8, 8, 16, 2}); + int8_nchw32_imma.emplace_back( + AlgoParam{128, 64, 32, 64, 32, 32, 8, 8, 16, 1}); + int8_nchw32_imma.emplace_back( + AlgoParam{128, 32, 32, 64, 32, 32, 8, 8, 16, 1}); + int8_nchw32_imma.emplace_back( + AlgoParam{64, 128, 32, 32, 64, 32, 8, 8, 16, 1}); + int8_nchw32_imma.emplace_back( + AlgoParam{32, 128, 32, 32, 64, 32, 8, 8, 16, 1}); } { using AlgoParam = AlgoInt4Int4NCHW64IMMAImplicitGemm::AlgoParam; int4_int4_nchw64_imma.emplace_back( - AlgoParam{128, 128, 128, 64, 64, 128, 2}); + AlgoParam{128, 128, 128, 64, 64, 128, 8, 8, 32, 2}); int4_int4_nchw64_imma.emplace_back( - AlgoParam{128, 256, 128, 64, 64, 128, 2}); + AlgoParam{128, 256, 128, 64, 64, 128, 8, 8, 32, 2}); int4_int4_nchw64_imma.emplace_back( - AlgoParam{128, 64, 128, 64, 64, 128, 2}); + AlgoParam{128, 64, 128, 64, 64, 128, 8, 8, 32, 2}); int4_int4_nchw64_imma.emplace_back( - AlgoParam{128, 64, 64, 64, 64, 64, 1}); + AlgoParam{128, 64, 64, 64, 64, 64, 8, 8, 32, 1}); } { using AlgoParam = AlgoUInt4Int4NCHW64IMMAImplicitGemm::AlgoParam; uint4_int4_nchw64_imma.emplace_back( 
- AlgoParam{128, 128, 128, 64, 64, 128, 2}); + AlgoParam{128, 128, 128, 64, 64, 128, 8, 8, 32, 2}); uint4_int4_nchw64_imma.emplace_back( - AlgoParam{128, 256, 128, 64, 64, 128, 2}); + AlgoParam{128, 256, 128, 64, 64, 128, 8, 8, 32, 2}); uint4_int4_nchw64_imma.emplace_back( - AlgoParam{128, 64, 128, 64, 64, 128, 2}); + AlgoParam{128, 64, 128, 64, 64, 128, 8, 8, 32, 2}); uint4_int4_nchw64_imma.emplace_back( - AlgoParam{128, 64, 64, 64, 64, 64, 1}); + AlgoParam{128, 64, 64, 64, 64, 64, 8, 8, 32, 1}); } { using AlgoParam = AlgoInt4Int4NHWCIMMAImplicitGemm::AlgoParam; int4_int4_nhwc_imma.emplace_back( - AlgoParam{128, 32, 64, 64, 32, 64, 1, 32}); + AlgoParam{128, 32, 64, 64, 32, 64, 8, 8, 32, 1, 32}); int4_int4_nhwc_imma.emplace_back( - AlgoParam{128, 32, 64, 64, 32, 64, 1, 16}); + AlgoParam{128, 32, 64, 64, 32, 64, 8, 8, 32, 1, 16}); int4_int4_nhwc_imma.emplace_back( - AlgoParam{128, 32, 64, 64, 32, 64, 1, 8}); + AlgoParam{128, 32, 64, 64, 32, 64, 8, 8, 32, 1, 8}); int4_int4_nhwc_imma.emplace_back( - AlgoParam{128, 64, 64, 64, 64, 64, 1, 32}); + AlgoParam{128, 64, 64, 64, 64, 64, 8, 8, 32, 1, 32}); int4_int4_nhwc_imma.emplace_back( - AlgoParam{128, 64, 64, 64, 64, 64, 1, 16}); + AlgoParam{128, 64, 64, 64, 64, 64, 8, 8, 32, 1, 16}); int4_int4_nhwc_imma.emplace_back( - AlgoParam{128, 64, 64, 64, 64, 64, 1, 8}); + AlgoParam{128, 64, 64, 64, 64, 64, 8, 8, 32, 1, 8}); } { using AlgoParam = AlgoUInt4Int4NHWCIMMAImplicitGemm::AlgoParam; uint4_int4_nhwc_imma.emplace_back( - AlgoParam{128, 32, 64, 64, 32, 64, 1, 32}); + AlgoParam{128, 32, 64, 64, 32, 64, 8, 8, 32, 1, 32}); uint4_int4_nhwc_imma.emplace_back( - AlgoParam{128, 32, 64, 64, 32, 64, 1, 16}); + AlgoParam{128, 32, 64, 64, 32, 64, 8, 8, 32, 1, 16}); uint4_int4_nhwc_imma.emplace_back( - AlgoParam{128, 32, 64, 64, 32, 64, 1, 8}); + AlgoParam{128, 32, 64, 64, 32, 64, 8, 8, 32, 1, 8}); uint4_int4_nhwc_imma.emplace_back( - AlgoParam{128, 64, 64, 64, 64, 64, 1, 32}); + AlgoParam{128, 64, 64, 64, 64, 64, 8, 8, 32, 1, 32}); uint4_int4_nhwc_imma.emplace_back( - AlgoParam{128, 64, 64, 64, 64, 64, 1, 16}); + AlgoParam{128, 64, 64, 64, 64, 64, 8, 8, 32, 1, 16}); uint4_int4_nhwc_imma.emplace_back( - AlgoParam{128, 64, 64, 64, 64, 64, 1, 8}); + AlgoParam{128, 64, 64, 64, 64, 64, 8, 8, 32, 1, 8}); } #endif } @@ -286,15 +295,24 @@ void ConvBiasForwardImpl::AlgoPack::fill_imma_algos() { void ConvBiasForwardImpl::AlgoPack::fill_dp4a_algos() { using AlgoParam = AlgoInt8NCHW4DotProdImplicitGemm::AlgoParam; - int8_nchw4_dotprod.emplace_back(AlgoParam{128, 128, 32, 64, 32, 32, 2}); - int8_nchw4_dotprod.emplace_back(AlgoParam{128, 64, 32, 64, 32, 32, 2}); - int8_nchw4_dotprod.emplace_back(AlgoParam{64, 128, 32, 64, 32, 32, 2}); - int8_nchw4_dotprod.emplace_back(AlgoParam{32, 128, 32, 32, 64, 32, 2}); - int8_nchw4_dotprod.emplace_back(AlgoParam{128, 32, 32, 64, 32, 32, 2}); - int8_nchw4_dotprod.emplace_back(AlgoParam{32, 64, 32, 32, 64, 32, 2}); - int8_nchw4_dotprod.emplace_back(AlgoParam{64, 32, 32, 64, 32, 32, 2}); - int8_nchw4_dotprod.emplace_back(AlgoParam{16, 128, 16, 16, 128, 16, 1}); - int8_nchw4_dotprod.emplace_back(AlgoParam{16, 64, 8, 16, 64, 8, 2}); + int8_nchw4_dotprod.emplace_back( + AlgoParam{128, 128, 32, 64, 32, 32, 1, 1, 4, 2}); + int8_nchw4_dotprod.emplace_back( + AlgoParam{128, 64, 32, 64, 32, 32, 1, 1, 4, 2}); + int8_nchw4_dotprod.emplace_back( + AlgoParam{64, 128, 32, 64, 32, 32, 1, 1, 4, 2}); + int8_nchw4_dotprod.emplace_back( + AlgoParam{32, 128, 32, 32, 64, 32, 1, 1, 4, 2}); + int8_nchw4_dotprod.emplace_back( + AlgoParam{128, 32, 32, 64, 32, 
32, 1, 1, 4, 2}); + int8_nchw4_dotprod.emplace_back( + AlgoParam{32, 64, 32, 32, 64, 32, 1, 1, 4, 2}); + int8_nchw4_dotprod.emplace_back( + AlgoParam{64, 32, 32, 64, 32, 32, 1, 1, 4, 2}); + int8_nchw4_dotprod.emplace_back( + AlgoParam{16, 128, 16, 16, 128, 16, 1, 1, 4, 1}); + int8_nchw4_dotprod.emplace_back( + AlgoParam{16, 64, 8, 16, 64, 8, 1, 1, 4, 2}); } ConvBiasForwardImpl::AlgoBase* diff --git a/dnn/src/cuda/conv_bias/algo.h b/dnn/src/cuda/conv_bias/algo.h index fe4f044e..2682391b 100644 --- a/dnn/src/cuda/conv_bias/algo.h +++ b/dnn/src/cuda/conv_bias/algo.h @@ -28,6 +28,17 @@ #include #include +namespace cutlass { +namespace library { + +// forward declaration of cutlass library concepts, we hope that algo.h does +// not depend on cutlass headers + +class Operation; + +} // namespace library +} // namespace cutlass + namespace megdnn { namespace cuda { @@ -505,9 +516,44 @@ public: MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_CHWN4_DOTPROD_INT8) }; -class ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm final - : public AlgoBase { +/*********************** Cutlass Algorithms ************************/ + +/* The inheritance of cutlass algorithm classes: + * + * AlgoCutlassConvolutionBase + * + + * +--- AlgoInt8NCHW4DotProdImplicitGemm + * +--- AlgoInt8NCHW32IMMAImplicitGemm + * + + * +--- AlgoInt4NCHW64IMMAImplicitGemmBase + * +----+--- AlgoInt4Int4NCHW64IMMAImplicitGemm + * +----+--- AlgoUInt4Int4NCHW64IMMAImplicitGemm + * + + * +--- AlgoInt4NHWCIMMAImplicitGemmBase + * +----+--- AlgoInt4Int4NHWCIMMAImplicitGemm + * +----+--- AlgoUInt4Int4NHWCIMMAImplicitGemm + * + + */ + +/* + * The base class for all cutlass algorithm classes + */ +class ConvBiasForwardImpl::AlgoCutlassConvolutionBase : public AlgoBase { public: + // corresponds to cutlass::conv::Operator. we hope that algo.h does not + // depend on cutlass headers + enum class ConvOperator { kFprop, kDgrad, kWgrad }; + + // corresponds to cutlass::conv::ConvType. we hope that algo.h does not + // depend on cutlass headers + enum class ConvType { + kConvolution, + kBatchConvolution, + kLocal, + kLocalShare + }; + + // common parameters for operation selection struct AlgoParam { int threadblock_m; int threadblock_n; @@ -515,21 +561,54 @@ public: int warp_m; int warp_n; int warp_k; + int instruction_m; + int instruction_n; + int instruction_k; int stage; - std::string to_string() { - /// default algorithm - if (threadblock_m == 128 && threadblock_n == 128 && - threadblock_k == 32 && warp_m == 32 && warp_n == 64 && - warp_k == 32 && stage == 2) { - return ""; - } - return ssprintf("_%dX%dX%d_%dX%dX%d_%dstage", threadblock_m, - threadblock_n, threadblock_k, warp_m, warp_n, - warp_k, stage); - } + int access_size; + + AlgoParam(int threadblock_m_, int threadblock_n_, int threadblock_k_, + int warp_m_, int warp_n_, int warp_k_, int instruction_m_, + int instruction_n_, int instruction_k_, int stage_, + int access_size_ = 0); + + std::string to_string() const; }; + + AlgoCutlassConvolutionBase(AlgoParam algo_param) + : m_algo_param{algo_param} {} + + // generate a cutlass::library::ConvolutionKey and find the corresponding + // operation (cutlass kernel) from the global OperationTable + const cutlass::library::Operation* get_cutlass_conv_op( + const SizeArgs& args, ConvOperator conv_op, ConvType conv_type, + bool load_from_const, bool without_shared_load) const; + + // execute the cutlass kernel found by get_cutlass_conv_op. 
we give + // subclasses full freedom to decide where and how these arguments are + // extracted + void execute_cutlass_conv_op(const cutlass::library::Operation* op, + const void* src, const void* filter, + const void* bias, const void* z, void* dst, + void* workspace, size_t n, size_t hi, + size_t wi, size_t ci, size_t co, size_t fh, + size_t fw, size_t ho, size_t wo, size_t ph, + size_t pw, size_t sh, size_t sw, size_t dh, + size_t dw, const void* alpha, const void* beta, + const void* gamma, const void* delta, + const void* theta, const void* threshold, + const void* dst_scale, cudaStream_t stream, + const void* extra_param = nullptr) const; + +protected: + AlgoParam m_algo_param; +}; + +class ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm final + : public AlgoCutlassConvolutionBase { +public: AlgoInt8NCHW4DotProdImplicitGemm(AlgoParam algo_param) - : m_algo_param{algo_param}, + : AlgoCutlassConvolutionBase(algo_param), m_name{ssprintf("INT8_NCHW4_DOTPROD_IMPLICIT_GEMM%s", m_algo_param.to_string().c_str())} {} bool is_available(const SizeArgs& args) const override; @@ -555,7 +634,6 @@ public: private: WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, const SizeArgs& args) const; - AlgoParam m_algo_param; std::string m_name; }; @@ -714,19 +792,10 @@ private: #if CUDA_VERSION >= 10020 class ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm final - : public AlgoBase { + : public AlgoCutlassConvolutionBase { public: - struct AlgoParam { - int threadblock_m; - int threadblock_n; - int threadblock_k; - int warp_m; - int warp_n; - int warp_k; - int stage; - }; AlgoInt8NCHW32IMMAImplicitGemm(AlgoParam algo_param) - : m_algo_param{algo_param} { + : AlgoCutlassConvolutionBase(algo_param) { m_name = ConvBias::algo_name( ssprintf("INT8_NCHW32_IMMA_IMPLICIT_GEMM_%s", to_string(m_algo_param).c_str()), @@ -757,25 +826,14 @@ private: WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, const SizeArgs& args) const; - AlgoParam m_algo_param; std::string m_name; }; class ConvBiasForwardImpl::AlgoInt4NCHW64IMMAImplicitGemmBase - : public AlgoBase { + : public AlgoCutlassConvolutionBase { public: - struct AlgoParam { - int threadblock_m; - int threadblock_n; - int threadblock_k; - int warp_m; - int warp_n; - int warp_k; - int stage; - }; - AlgoInt4NCHW64IMMAImplicitGemmBase(AlgoParam algo_param) - : m_algo_param(algo_param) {} + : AlgoCutlassConvolutionBase(algo_param) {} AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; @@ -799,16 +857,9 @@ protected: virtual std::tuple get_constants( const ExecArgs& args) const = 0; - virtual void do_exec(const ExecArgs& args, void* filter_ptr, void* bias_ptr, - void* z_ptr, convolution::ConvParam kern_param, - uint32_t nonlinear_mode, float alpha, float beta, - float gamma, float delta, float theta, - cudaStream_t stream) const = 0; - void reorder_filter(const ExecArgs& args, void* reordered_filter) const; std::string m_name; - AlgoParam m_algo_param; }; class ConvBiasForwardImpl::AlgoInt4Int4NCHW64IMMAImplicitGemm final @@ -842,11 +893,6 @@ private: std::tuple get_constants( const ExecArgs& args) const override; - - void do_exec(const ExecArgs& args, void* filter_ptr, void* bias_ptr, - void* z_ptr, convolution::ConvParam kern_param, - uint32_t nonlinear_mode, float alpha, float beta, float gamma, - float delta, float theta, cudaStream_t stream) const override; }; class ConvBiasForwardImpl::AlgoUInt4Int4NCHW64IMMAImplicitGemm final @@ -881,30 +927,15 @@ private: std::tuple get_constants( const ExecArgs& args) const override; - 
void do_exec(const ExecArgs& args, void* filter_ptr, void* bias_ptr,
-                 void* z_ptr, convolution::ConvParam kern_param,
-                 uint32_t nonlinear_mode, float alpha, float beta, float gamma,
-                 float delta, float theta, cudaStream_t stream) const override;
-
     void update_bias(const ExecArgs& args, void* updated_bias,
                      void* reduce_filter_ptr, void* reduce_workspace) const;
 };
 
-class ConvBiasForwardImpl::AlgoInt4NHWCIMMAImplicitGemmBase : public AlgoBase {
+class ConvBiasForwardImpl::AlgoInt4NHWCIMMAImplicitGemmBase
+        : public AlgoCutlassConvolutionBase {
 public:
-    struct AlgoParam {
-        int threadblock_m;
-        int threadblock_n;
-        int threadblock_k;
-        int warp_m;
-        int warp_n;
-        int warp_k;
-        int stage;
-        int access_size;
-    };
-
-    AlgoInt4NHWCIMMAImplicitGemmBase(AlgoParam algo_param)
-            : m_algo_param(algo_param) {}
+    AlgoInt4NHWCIMMAImplicitGemmBase(AlgoParam algo_param)
+            : AlgoCutlassConvolutionBase(algo_param) {}
 
     AlgoAttribute attribute() const override {
         return AlgoAttribute::REPRODUCIBLE;
@@ -928,17 +959,10 @@ protected:
     virtual std::tuple get_constants(
             const ExecArgs& args) const = 0;
 
-    virtual void do_exec(const ExecArgs& args, void* filter_ptr, void* bias_ptr,
-                         void* z_ptr, convolution::ConvParam kern_param,
-                         uint32_t nonlinear_mode, float alpha, float beta,
-                         float gamma, float delta, float theta,
-                         cudaStream_t stream) const = 0;
-
     void reorder_filter(const ExecArgs& args, int interleaved,
                         void* reordered_filter) const;
 
     std::string m_name;
-    AlgoParam m_algo_param;
 };
 
 class ConvBiasForwardImpl::AlgoInt4Int4NHWCIMMAImplicitGemm final
@@ -971,11 +995,6 @@ private:
     std::tuple get_constants(
             const ExecArgs& args) const override;
-
-    void do_exec(const ExecArgs& args, void* filter_ptr, void* bias_ptr,
-                 void* z_ptr, convolution::ConvParam kern_param,
-                 uint32_t nonlinear_mode, float alpha, float beta, float gamma,
-                 float delta, float theta, cudaStream_t stream) const override;
 };
 
 class ConvBiasForwardImpl::AlgoUInt4Int4NHWCIMMAImplicitGemm final
@@ -1009,11 +1028,6 @@ private:
     std::tuple get_constants(
             const ExecArgs& args) const override;
 
-    void do_exec(const ExecArgs& args, void* filter_ptr, void* bias_ptr,
-                 void* z_ptr, convolution::ConvParam kern_param,
-                 uint32_t nonlinear_mode, float alpha, float beta, float gamma,
-                 float delta, float theta, cudaStream_t stream) const override;
-
     void update_bias(const ExecArgs& args, void* updated_bias,
                      void* reduce_filter_ptr, void* reduce_workspace) const;
 };
diff --git a/dnn/src/cuda/conv_bias/cutlass_convolution_base.cpp b/dnn/src/cuda/conv_bias/cutlass_convolution_base.cpp
new file mode 100644
index 00000000..4c4f04a2
--- /dev/null
+++ b/dnn/src/cuda/conv_bias/cutlass_convolution_base.cpp
@@ -0,0 +1,253 @@
+/**
+ * \file dnn/src/cuda/conv_bias/cutlass_convolution_base.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+
+#include "src/cuda/conv_bias/algo.h"
+#include "src/cuda/cutlass/singleton.h"
+
+namespace megdnn {
+namespace cuda {
+
+using namespace cutlass::library;
+using namespace cutlass::epilogue;
+
+ConvBiasForwardImpl::AlgoCutlassConvolutionBase::AlgoParam::AlgoParam(
+        int threadblock_m_, int threadblock_n_, int threadblock_k_, int warp_m_,
+        int warp_n_, int warp_k_, int instruction_m_, int instruction_n_,
+        int instruction_k_, int stage_, int access_size_)
+        : threadblock_m(threadblock_m_),
+          threadblock_n(threadblock_n_),
+          threadblock_k(threadblock_k_),
+          warp_m(warp_m_),
+          warp_n(warp_n_),
+          warp_k(warp_k_),
+          instruction_m(instruction_m_),
+          instruction_n(instruction_n_),
+          instruction_k(instruction_k_),
+          stage(stage_),
+          access_size(access_size_) {}
+
+std::string
+ConvBiasForwardImpl::AlgoCutlassConvolutionBase::AlgoParam::to_string() const {
+    /// default algorithm
+    if (threadblock_m == 128 && threadblock_n == 128 && threadblock_k == 32 &&
+        warp_m == 32 && warp_n == 64 && warp_k == 32 && stage == 2) {
+        return "";
+    }
+    return ssprintf("_%dX%dX%d_%dX%dX%d_%dstage", threadblock_m, threadblock_n,
+                    threadblock_k, warp_m, warp_n, warp_k, stage);
+}
+
+namespace {
+
+using Base = ConvBiasForwardImpl::AlgoCutlassConvolutionBase;
+
+cutlass::conv::Operator convert_conv_op(Base::ConvOperator conv_op) {
+    switch (conv_op) {
+        case Base::ConvOperator::kFprop:
+            return cutlass::conv::Operator::kFprop;
+        case Base::ConvOperator::kDgrad:
+            return cutlass::conv::Operator::kDgrad;
+        case Base::ConvOperator::kWgrad:
+            return cutlass::conv::Operator::kWgrad;
+        default:
+            megdnn_assert(0, "invalid conv op");
+    }
+}
+
+cutlass::conv::ConvType convert_conv_type(Base::ConvType conv_type) {
+    switch (conv_type) {
+        case Base::ConvType::kConvolution:
+            return cutlass::conv::ConvType::kConvolution;
+        case Base::ConvType::kBatchConvolution:
+            return cutlass::conv::ConvType::kBatchConvolution;
+        case Base::ConvType::kLocal:
+            return cutlass::conv::ConvType::kLocal;
+        case Base::ConvType::kLocalShare:
+            return cutlass::conv::ConvType::kLocalShare;
+        default:
+            megdnn_assert(0, "invalid conv type");
+    }
+}
+
+NumericTypeID convert_dtype(DTypeEnum dtype) {
+    switch (dtype) {
+        case DTypeEnum::Float32:
+            return NumericTypeID::kF32;
+        case DTypeEnum::Float16:
+            return NumericTypeID::kF16;
+        case DTypeEnum::Int8:
+            return NumericTypeID::kS8;
+        case DTypeEnum::QuantizedS32:
+            return NumericTypeID::kS32;
+        case DTypeEnum::QuantizedS8:
+            return NumericTypeID::kS8;
+        case DTypeEnum::QuantizedS4:
+            return NumericTypeID::kS4;
+        case DTypeEnum::Quantized4Asymm:
+            return NumericTypeID::kU4;
+        default:
+            megdnn_assert(0, "invalid dtype");
+    }
+}
+
+struct LayoutPack {
+    LayoutTypeID src;
+    LayoutTypeID filter;
+    LayoutTypeID dst;
+    LayoutTypeID bias;
+};
+
+LayoutPack get_layout_pack(const param::ConvBias::Format format,
+                           int access_type) {
+    using Format = param::ConvBias::Format;
+
+    switch (format) {
+        case Format::NCHW4:
+            return {LayoutTypeID::kTensorNC4HW4, LayoutTypeID::kTensorC4RSK4,
+                    LayoutTypeID::kTensorNC4HW4, LayoutTypeID::kTensorNC4HW4};
+        case Format::NCHW4_NCHW:
+            return {LayoutTypeID::kTensorNC4HW4, LayoutTypeID::kTensorC4RSK4,
+                    LayoutTypeID::kTensorNCHW, LayoutTypeID::kTensorNCHW};
+        case Format::NCHW4_NHWC:
+            return {LayoutTypeID::kTensorNC4HW4, LayoutTypeID::kTensorC4RSK4,
+                    LayoutTypeID::kTensorNHWC, LayoutTypeID::kTensorNHWC};
+        case Format::NCHW4_NCHW32:
+            return {LayoutTypeID::kTensorNC4HW4, LayoutTypeID::kTensorC4RSK4,
+                    LayoutTypeID::kTensorNC32HW32,
+                    LayoutTypeID::kTensorNC32HW32};
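+        // the remaining dense formats map one-to-one; for NHWC the filter
+        // interleaving (NC8HW8 / NC16HW16 / NC32HW32) is selected by the
+        // access_type argument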
+ case Format::NCHW32: + return {LayoutTypeID::kTensorNC32HW32, + LayoutTypeID::kTensorC32RSK32, + LayoutTypeID::kTensorNC32HW32, + LayoutTypeID::kTensorNC32HW32}; + case Format::NCHW32_NCHW4: + return {LayoutTypeID::kTensorNC32HW32, + LayoutTypeID::kTensorC32RSK32, LayoutTypeID::kTensorNC4HW4, + LayoutTypeID::kTensorNC4HW4}; + case Format::NCHW64: + return {LayoutTypeID::kTensorNC64HW64, + LayoutTypeID::kTensorC64RSK64, + LayoutTypeID::kTensorNC64HW64, + LayoutTypeID::kTensorNC64HW64}; + case Format::NHWC: + switch (access_type) { + case 8: + return {LayoutTypeID::kTensorNHWC, + LayoutTypeID::kTensorNC8HW8, + LayoutTypeID::kTensorNHWC, + LayoutTypeID::kTensorNHWC}; + case 16: + return {LayoutTypeID::kTensorNHWC, + LayoutTypeID::kTensorNC16HW16, + LayoutTypeID::kTensorNHWC, + LayoutTypeID::kTensorNHWC}; + case 32: + return {LayoutTypeID::kTensorNHWC, + LayoutTypeID::kTensorNC32HW32, + LayoutTypeID::kTensorNHWC, + LayoutTypeID::kTensorNHWC}; + default: + megdnn_assert(0, "invalid access_type"); + } + default: + megdnn_assert(0, "invalid format"); + } +} + +EpilogueType get_epilogue_type(const param::ConvBias::NonlineMode mode, + bool clamp) { + using NonlineMode = param::ConvBias::NonlineMode; + + if (clamp) { + if (mode == NonlineMode::IDENTITY) { + return EpilogueType::kBiasAddLinearCombinationClamp; + } else if (mode == NonlineMode::RELU) { + return EpilogueType::kBiasAddLinearCombinationReluClamp; + } else if (mode == NonlineMode::H_SWISH) { + return EpilogueType::kBiasAddLinearCombinationHSwishClamp; + } + } else { + if (mode == NonlineMode::IDENTITY) { + return EpilogueType::kBiasAddLinearCombination; + } else if (mode == NonlineMode::RELU) { + return EpilogueType::kBiasAddLinearCombinationRelu; + } else if (mode == NonlineMode::H_SWISH) { + return EpilogueType::kBiasAddLinearCombinationHSwish; + } + } + megdnn_assert(0, "invalid nonlinear mode"); +} + +} // namespace + +const Operation* +ConvBiasForwardImpl::AlgoCutlassConvolutionBase::get_cutlass_conv_op( + const SizeArgs& args, ConvOperator conv_op, ConvType conv_type, + bool load_from_const, bool without_shared_load) const { + using Format = param::ConvBias::Format; + auto&& param = args.opr->param(); + auto layouts = get_layout_pack(param.format, m_algo_param.access_size); + auto epilogue_type = get_epilogue_type(param.nonlineMode, + param.format != Format::NCHW4_NCHW); + ConvolutionKey key{convert_conv_op(conv_op), + convert_dtype(args.src_layout->dtype.enumv()), + layouts.src, + convert_dtype(args.filter_layout->dtype.enumv()), + layouts.filter, + convert_dtype(args.dst_layout->dtype.enumv()), + layouts.dst, + convert_dtype(args.bias_layout->dtype.enumv()), + layouts.bias, + convert_conv_type(conv_type), + m_algo_param.threadblock_m, + m_algo_param.threadblock_n, + m_algo_param.threadblock_k, + m_algo_param.warp_m, + m_algo_param.warp_n, + m_algo_param.warp_k, + m_algo_param.instruction_m, + m_algo_param.instruction_n, + m_algo_param.instruction_k, + epilogue_type, + m_algo_param.stage, + load_from_const, + without_shared_load}; + + return Singleton::get().operation_table.find_op(key); +} + +void ConvBiasForwardImpl::AlgoCutlassConvolutionBase::execute_cutlass_conv_op( + const Operation* op, const void* src, const void* filter, + const void* bias, const void* z, void* dst, void* workspace, size_t n, + size_t hi, size_t wi, size_t ci, size_t co, size_t fh, size_t fw, + size_t ho, size_t wo, size_t ph, size_t pw, size_t sh, size_t sw, + size_t dh, size_t dw, const void* alpha, const void* beta, + const void* gamma, const void* 
delta, const void* theta, + const void* threshold, const void* dst_scale, cudaStream_t stream, + const void* extra_param) const { + // gcc prints warnings when size_t values are implicitly narrowed to int + cutlass::conv::Conv2dProblemSize problem_size{ + int(n), int(hi), int(wi), int(ci), + int(co), int(fh), int(fw), int(ho), + int(wo), int(ph), int(pw), int(sh), + int(sw), int(dh), int(dw), cutlass::conv::Mode::kCrossCorrelation}; + + ConvolutionArguments conv_args{ + problem_size, src, filter, bias, z, + dst, alpha, beta, gamma, delta, + theta, threshold, dst_scale, extra_param}; + + cutlass_check(op->run(&conv_args, workspace, stream)); +} + +} // namespace cuda +} // namespace megdnn diff --git a/dnn/src/cuda/conv_bias/cutlass_convolution_wrapper.cuh b/dnn/src/cuda/conv_bias/cutlass_convolution_wrapper.cuh deleted file mode 100644 index 424cdd61..00000000 --- a/dnn/src/cuda/conv_bias/cutlass_convolution_wrapper.cuh +++ /dev/null @@ -1,129 +0,0 @@ -/** - * \file dnn/src/cuda/conv_bias/cutlass_convolution_wrapper.cuh - * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") - * - * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or - * implied. - */ -#pragma once -#include "cutlass/gemm/gemm.h" -#include "src/cuda/convolution_helper/parameter.cuh" -#include "src/cuda/utils.cuh" - -namespace megdnn { -namespace cuda { -namespace cutlass_wrapper { - -using GemmCoord = cutlass::gemm::GemmCoord; - -template -void cutlass_convolution_wrapper( - const typename Convolution::ElementSrc* d_src, - const typename Convolution::ElementFilter* d_filter, - const typename Convolution::ElementBias* d_bias, - const typename Convolution::ElementDst* d_z, - typename Convolution::ElementDst* d_dst, int* workspace, - typename Convolution::ConvolutionParameter const& conv_param, - typename Convolution::EpilogueOutputOp::Params const& epilogue, - cudaStream_t stream, typename Convolution::ExtraParam extra_param = {}); - -template -void do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32( - const int8_t* d_src, const int8_t* d_filter, const int32_t* d_bias, - const int8_t* d_z, int8_t* d_dst, int* workspace, - const convolution::ConvParam& param, uint32_t nonlinear_mode, - float alpha, float beta, float gamma, float scale, - const GemmCoord& threadblock_shape, const GemmCoord& warp_shape, - int stages, cudaStream_t stream); - -template -void do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4( - const int8_t* d_src, const int8_t* d_filter, const int32_t* d_bias, - const int8_t* d_z, int8_t* d_dst, int* workspace, - const convolution::ConvParam& param, uint32_t nonlinear_mode, - float alpha, float beta, float gamma, float scale, - const GemmCoord& threadblock_shape, const GemmCoord& warp_shape, - int stages, cudaStream_t stream); - -template -void do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4( - const int8_t* d_src, const int8_t* d_filter, const int32_t* d_bias, - const int8_t* d_z, int8_t* d_dst, int* workspace, - const convolution::ConvParam& param, uint32_t nonlinear_mode, - float alpha, float beta, float gamma, float scale, - const GemmCoord& threadblock_shape, const GemmCoord& warp_shape, - int stages, cudaStream_t stream); - -template -void do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nchw( - const int8_t* d_src, const int8_t* d_filter, const float* d_bias, - 
const float* d_z, float* d_dst, int* workspace, - const convolution::ConvParam& param, uint32_t nonlinear_mode, - float alpha, float beta, float gamma, float scale, - const GemmCoord& threadblock_shape, const GemmCoord& warp_shape, - int stages, cudaStream_t stream); - -template -void do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32( - const int8_t* d_src, const int8_t* d_filter, const int32_t* d_bias, - const int8_t* d_z, int8_t* d_dst, int* workspace, - const convolution::ConvParam& param, uint32_t nonlinear_mode, - float alpha, float beta, float gamma, float scale, - const GemmCoord& threadblock_shape, const GemmCoord& warp_shape, - int stages, cudaStream_t stream); - -template -void do_conv_bias_int4_int4_implicit_gemm_imma_ncdiv64hw64( - const int8_t* d_src, const int8_t* d_filter, const int32_t* d_bias, - const int8_t* d_z, int8_t* d_dst, int* workspace, - const convolution::ConvParam& param, uint32_t nonlinear_mode, - float alpha, float beta, float gamma, float scale, - const GemmCoord& threadblock_shape, const GemmCoord& warp_shape, - int stages, cudaStream_t stream); - -template -void do_conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64( - const uint8_t* d_src, const int8_t* d_filter, const int32_t* d_bias, - const uint8_t* d_z, uint8_t* d_dst, int* workspace, - const convolution::ConvParam& param, uint32_t nonlinear_mode, - float alpha, float beta, float gamma, float delta, float theta, - float scale, uint8_t src_zero_point, const GemmCoord& threadblock_shape, - const GemmCoord& warp_shape, int stages, cudaStream_t stream); - -template -void do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nhwc( - const int8_t* d_src, const int8_t* d_filter, const int32_t* d_bias, - const int8_t* d_z, int8_t* d_dst, int* workspace, - const convolution::ConvParam& param, uint32_t nonlinear_mode, - float alpha, float beta, float gamma, float delta, float theta, - float scale, const GemmCoord& threadblock_shape, - const GemmCoord& warp_shape, int stages, cudaStream_t stream); - -template -void do_conv_bias_int4_int4_implicit_gemm_imma_nhwc( - const int8_t* d_src, const int8_t* d_filter, const int32_t* d_bias, - const int8_t* d_z, int8_t* d_dst, int* workspace, - const convolution::ConvParam& param, uint32_t nonlinear_mode, - float alpha, float beta, float gamma, float scale, - const GemmCoord& threadblock_shape, const GemmCoord& warp_shape, - const int32_t access_size, int stages, cudaStream_t stream); - -template -void do_conv_bias_uint4_int4_implicit_gemm_imma_nhwc( - const uint8_t* d_src, const int8_t* d_filter, const int32_t* d_bias, - const uint8_t* d_z, uint8_t* d_dst, int* workspace, - const convolution::ConvParam& param, uint32_t nonlinear_mode, - float alpha, float beta, float gamma, float delta, float theta, - float scale, uint8_t src_zero_point, const GemmCoord& threadblock_shape, - const GemmCoord& warp_shape, const int32_t access_size, int stages, - cudaStream_t stream); - -} // namespace cutlass_wrapper -} // namespace cuda -} // namespace megdnn - -// vim: syntax=cuda.doxygen diff --git a/dnn/src/cuda/conv_bias/cutlass_convolution_wrapper_int4.cu b/dnn/src/cuda/conv_bias/cutlass_convolution_wrapper_int4.cu deleted file mode 100644 index d29a41a5..00000000 --- a/dnn/src/cuda/conv_bias/cutlass_convolution_wrapper_int4.cu +++ /dev/null @@ -1,595 +0,0 @@ -/** - * \file dnn/src/cuda/conv_bias/cutlass_convolution_wrapper.cu - * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") - * - * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. 
- * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or - * implied. - */ -// ignore warning of cutlass -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wunused-parameter" -#pragma GCC diagnostic ignored "-Wstrict-aliasing" - -#if !MEGDNN_TEGRA_X1 -#include "cutlass/convolution/device/convolution.h" -#endif -#include "src/common/opr_param_defs_enumv.cuh" -#include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh" -#pragma GCC diagnostic pop - -using namespace megdnn; -using namespace cuda; -using namespace cutlass_wrapper; - -/* ====== cutlass kernel wrapper for int4 x int4 nchw64 layout ====== */ - -#if MEGDNN_TEGRA_X1 -template -void megdnn::cuda::cutlass_wrapper:: - do_conv_bias_int4_int4_implicit_gemm_imma_ncdiv64hw64( - const int8_t* /* d_src */, const int8_t* /* d_filter */, - const int32_t* /* d_bias */, const int8_t* /* d_z */, - int8_t* /* d_dst */, int* /* workspace */, - const convolution::ConvParam& /* param */, - uint32_t /* nonlinear_mode */, float /* alpha */, - float /* beta */, float /* gamma */, float /* scale */, - const GemmCoord& /* threadblock_shape */, - const GemmCoord& /* warp_shape */, int /* stages */, - cudaStream_t /* stream */) {} -#else -template -void megdnn::cuda::cutlass_wrapper:: - do_conv_bias_int4_int4_implicit_gemm_imma_ncdiv64hw64( - const int8_t* d_src, const int8_t* d_filter, - const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, - int* workspace, const convolution::ConvParam& param, - uint32_t nonlinear_mode, float alpha, float beta, float gamma, - float scale, const GemmCoord& threadblock_shape, - const GemmCoord& warp_shape, int stages, cudaStream_t stream) { -#define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \ - threadblock_k_, warp_m_, warp_n_, \ - warp_k_, stage_) \ - if (threadblock_shape.m() == threadblock_m_ && \ - threadblock_shape.n() == threadblock_n_ && \ - threadblock_shape.k() == threadblock_k_ && \ - warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ - warp_shape.k() == warp_k_ && stages == stage_) { \ - using ThreadBlockShape = \ - cutlass::gemm::GemmShape; \ - using WarpShape = cutlass::gemm::GemmShape; \ - using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; \ - using Convolution = cutlass::conv::device::Convolution< \ - cutlass::int4b_t, cutlass::layout::TensorNCxHWx<64>, \ - cutlass::int4b_t, cutlass::layout::TensorCxRSKx<64>, \ - ElementOutput, cutlass::layout::TensorNCxHWx<64>, int32_t, \ - cutlass::layout::TensorNCxHWx<64>, int32_t, \ - cutlass::conv::ConvType::kConvolution, \ - cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \ - ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ - cutlass::conv::threadblock:: \ - ConvolutionFpropTransThreadblockSwizzle, \ - stage_, 32, 32, NeedLoadFromConstMem, \ - cutlass::arch::OpMultiplyAddSaturate, \ - cutlass::conv::ImplicitGemmMode::GEMM_TN, true>; \ - typename Convolution::ConvolutionParameter conv_param( \ - param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ - param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ - param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \ - return cutlass_convolution_wrapper( \ - reinterpret_cast(d_src), \ - reinterpret_cast(d_filter), d_bias, \ - reinterpret_cast(d_z), \ - reinterpret_cast(d_dst), workspace, \ - conv_param, epilogue, stream); \ - } -#define DISPATCH_KERNEL \ - 
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 128, 64, 64, 128, 2); \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 256, 128, 64, 64, 128, 2); \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 128, 64, 64, 128, 2); \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1); \ - megdnn_assert(false, \ - "unsupported threadblock shape (%dx%dx%d) and warp shape " \ - "(%dx%dx%d)", \ - threadblock_shape.m(), threadblock_shape.n(), \ - threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ - warp_shape.k()); - using ElementOutput = cutlass::int4b_t; - using ElementAccumulator = int32_t; - using ElementBias = int32_t; - using ElementCompute = float; - using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode; - switch (nonlinear_mode) { - case NonlineMode::IDENTITY: { - using EpilogueOp = - cutlass::epilogue::thread::BiasAddLinearCombinationClamp< - ElementOutput, 16, ElementAccumulator, ElementBias, - ElementCompute>; - typename EpilogueOp::Params epilogue{alpha, beta, gamma}; - DISPATCH_KERNEL; - } - case NonlineMode::RELU: { - using EpilogueOp = cutlass::epilogue::thread:: - BiasAddLinearCombinationReluClamp< - ElementOutput, 16, ElementAccumulator, ElementBias, - ElementCompute>; - typename EpilogueOp::Params epilogue{alpha, beta, gamma, 0}; - DISPATCH_KERNEL; - } - case NonlineMode::H_SWISH: { - using EpilogueOp = cutlass::epilogue::thread:: - BiasAddLinearCombinationHSwishClamp< - ElementOutput, 16, ElementAccumulator, ElementBias, - ElementCompute>; - typename EpilogueOp::Params epilogue{alpha, beta, gamma, scale}; - DISPATCH_KERNEL; - } - default: - megdnn_assert(false, - "unsupported nonlinear mode for conv bias operator"); - } -#undef DISPATCH_KERNEL_WITH_TILE_SHAPE -#undef DISPATCH_KERNEL -} -#endif - -#define INST(need_load_from_const_mem) \ - template void megdnn::cuda::cutlass_wrapper:: \ - do_conv_bias_int4_int4_implicit_gemm_imma_ncdiv64hw64< \ - need_load_from_const_mem>( \ - const int8_t* d_src, const int8_t* d_filter, \ - const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, \ - int* workspace, const convolution::ConvParam& param, \ - uint32_t nonlinear_mode, float alpha, float beta, \ - float gamma, float scale, \ - const GemmCoord& threadblock_shape, \ - const GemmCoord& warp_shape, int stages, \ - cudaStream_t stream); -INST(true); -#undef INST - -/* ====== cutlass kernel wrapper for uint4 x int4 nchw64 layout ====== */ - -#if MEGDNN_TEGRA_X1 -template -void megdnn::cuda::cutlass_wrapper:: - do_conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64( - const uint8_t* /* d_src */, const int8_t* /* d_filter */, - const int32_t* /* d_bias */, const uint8_t* /* d_z */, - uint8_t* /* d_dst */, int* /* workspace */, - const convolution::ConvParam& /* param */, - uint32_t /* nonlinear_mode */, float /* alpha */, - float /* beta */, float /* gamma */, float /* delta */, - float /* theta */, float /* scale */, - uint8_t /* src_zero_point */, - const GemmCoord& /* threadblock_shape */, - const GemmCoord& /* warp_shape */, int /* stages */, - cudaStream_t /* stream */) {} -#else -template -void megdnn::cuda::cutlass_wrapper:: - do_conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64( - const uint8_t* d_src, const int8_t* d_filter, - const int32_t* d_bias, const uint8_t* d_z, uint8_t* d_dst, - int* workspace, const convolution::ConvParam& param, - uint32_t nonlinear_mode, float alpha, float beta, float gamma, - float delta, float theta, float /* scale */, - uint8_t src_zero_point, const GemmCoord& threadblock_shape, - const GemmCoord& warp_shape, int stages, cudaStream_t stream) { -#define 
DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \ - threadblock_k_, warp_m_, warp_n_, \ - warp_k_, stage_) \ - if (threadblock_shape.m() == threadblock_m_ && \ - threadblock_shape.n() == threadblock_n_ && \ - threadblock_shape.k() == threadblock_k_ && \ - warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ - warp_shape.k() == warp_k_ && stages == stage_) { \ - using ThreadBlockShape = \ - cutlass::gemm::GemmShape; \ - using WarpShape = cutlass::gemm::GemmShape; \ - using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; \ - using Convolution = cutlass::conv::device::Convolution< \ - cutlass::uint4b_t, cutlass::layout::TensorNCxHWx<64>, \ - cutlass::int4b_t, cutlass::layout::TensorCxRSKx<64>, \ - ElementOutput, cutlass::layout::TensorNCxHWx<64>, int32_t, \ - cutlass::layout::TensorNCxHWx<64>, int32_t, \ - cutlass::conv::ConvType::kConvolution, \ - cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \ - ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ - cutlass::conv::threadblock:: \ - ConvolutionFpropTransThreadblockSwizzle, \ - stage_, 32, 32, NeedLoadFromConstMem, \ - cutlass::arch::OpMultiplyAddSaturate, \ - cutlass::conv::ImplicitGemmMode::GEMM_TN, true>; \ - typename Convolution::ConvolutionParameter conv_param( \ - param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ - param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ - param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \ - return cutlass_convolution_wrapper( \ - reinterpret_cast(d_src), \ - reinterpret_cast(d_filter), d_bias, \ - reinterpret_cast(d_z), \ - reinterpret_cast(d_dst), workspace, \ - conv_param, epilogue, stream, {src_zero_point}); \ - } -#define DISPATCH_KERNEL \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 128, 64, 64, 128, 2); \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 256, 128, 64, 64, 128, 2); \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 128, 64, 64, 128, 2); \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1); \ - megdnn_assert(false, \ - "unsupported threadblock shape (%dx%dx%d) and warp shape " \ - "(%dx%dx%d)", \ - threadblock_shape.m(), threadblock_shape.n(), \ - threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ - warp_shape.k()); - using ElementOutput = cutlass::uint4b_t; - using ElementAccumulator = int32_t; - using ElementBias = int32_t; - using ElementCompute = float; - using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode; - switch (nonlinear_mode) { - case NonlineMode::IDENTITY: { - using EpilogueOp = - cutlass::epilogue::thread::BiasAddLinearCombinationClamp< - ElementOutput, 16, ElementAccumulator, ElementBias, - ElementCompute>; - typename EpilogueOp::Params epilogue{alpha, beta, gamma, - delta + theta}; - DISPATCH_KERNEL; - } - case NonlineMode::RELU: { - using EpilogueOp = cutlass::epilogue::thread:: - BiasAddLinearCombinationReluClamp< - ElementOutput, 16, ElementAccumulator, ElementBias, - ElementCompute>; - typename EpilogueOp::Params epilogue{alpha, beta, gamma, - 0, delta, theta}; - DISPATCH_KERNEL; - } - default: - megdnn_assert(false, - "unsupported nonlinear mode for conv bias operator"); - } -#undef DISPATCH_KERNEL_WITH_TILE_SHAPE -#undef DISPATCH_KERNEL -} -#endif - -#define INST(need_load_from_const_mem) \ - template void megdnn::cuda::cutlass_wrapper:: \ - do_conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64< \ - need_load_from_const_mem>( \ - const uint8_t* d_src, const int8_t* d_filter, \ - const int32_t* d_bias, const uint8_t* d_z, uint8_t* d_dst, \ - int* workspace, const 
convolution::ConvParam& param, \ - uint32_t nonlinear_mode, float alpha, float beta, \ - float gamma, float delta, float theta, float scale, \ - uint8_t src_zero_point, \ - const GemmCoord& threadblock_shape, \ - const GemmCoord& warp_shape, int stages, \ - cudaStream_t stream); -INST(true); -#undef INST - -/* ====== cutlass kernel wrapper for int4 x int4 nhwc layout ====== */ - -#if MEGDNN_TEGRA_X1 -template -void megdnn::cuda::cutlass_wrapper:: - do_conv_bias_int4_int4_implicit_gemm_imma_nhwc( - const int8_t* /* d_src */, const int8_t* /* d_filter */, - const int32_t* /* d_bias */, const int8_t* /* d_z */, - int8_t* /* d_dst */, int* /* workspace */, - const convolution::ConvParam& /* param */, - uint32_t /* nonlinear_mode */, float /* alpha */, - float /* beta */, float /* gamma */, float /* scale */, - const GemmCoord& /* threadblock_shape */, - const GemmCoord& /* warp_shape */, - const int32_t /* access_size */, int /* stages */, - cudaStream_t /* stream */) {} -#else -template -void megdnn::cuda::cutlass_wrapper:: - do_conv_bias_int4_int4_implicit_gemm_imma_nhwc( - const int8_t* d_src, const int8_t* d_filter, - const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, - int* workspace, const convolution::ConvParam& param, - uint32_t nonlinear_mode, float alpha, float beta, float gamma, - float scale, const GemmCoord& threadblock_shape, - const GemmCoord& warp_shape, const int32_t access_size, - int stages, cudaStream_t stream) { - bool without_shared_load = - ((param.co % threadblock_shape.n() == 0) && - (threadblock_shape.n() == 32 || threadblock_shape.n() == 64)); - int out_elements_per_access = - without_shared_load ? threadblock_shape.n() / 4 : 8; - -#define RUN_CUTLASS_WRAPPER(stage_, access_size_, without_shared_load_) \ - using Convolution = cutlass::conv::device::Convolution< \ - cutlass::int4b_t, cutlass::layout::TensorNHWC, cutlass::int4b_t, \ - cutlass::layout::TensorNCxHWx, ElementOutput, \ - cutlass::layout::TensorNHWC, int32_t, cutlass::layout::TensorNHWC, \ - int32_t, cutlass::conv::ConvType::kConvolution, \ - cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \ - ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ - cutlass::conv::threadblock:: \ - ConvolutionFpropTransThreadblockSwizzle, \ - stage_, access_size_, access_size_, NeedLoadFromConstMem, \ - cutlass::arch::OpMultiplyAddSaturate, \ - cutlass::conv::ImplicitGemmMode::GEMM_TN, without_shared_load_>; \ - typename Convolution::ConvolutionParameter conv_param( \ - param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ - param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ - param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \ - return cutlass_convolution_wrapper( \ - reinterpret_cast(d_src), \ - reinterpret_cast(d_filter), d_bias, \ - reinterpret_cast(d_z), \ - reinterpret_cast(d_dst), workspace, conv_param, \ - epilogue, stream); -#define DISPATCH_KERNEL_WITH_TILE_SHAPE( \ - threadblock_m_, threadblock_n_, threadblock_k_, warp_m_, warp_n_, \ - warp_k_, stage_, access_size_, out_elements_per_access_, \ - without_shared_load_) \ - if (threadblock_shape.m() == threadblock_m_ && \ - threadblock_shape.n() == threadblock_n_ && \ - threadblock_shape.k() == threadblock_k_ && \ - warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ - warp_shape.k() == warp_k_ && stages == stage_ && \ - access_size == access_size_ && \ - out_elements_per_access == out_elements_per_access_ && \ - without_shared_load == without_shared_load_) { \ - using ThreadBlockShape = \ - cutlass::gemm::GemmShape; 
\ - using WarpShape = cutlass::gemm::GemmShape; \ - using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; \ - using ElementOutput = cutlass::int4b_t; \ - using ElementAccumulator = int32_t; \ - using ElementBias = int32_t; \ - using ElementCompute = float; \ - using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode; \ - switch (nonlinear_mode) { \ - case NonlineMode::IDENTITY: { \ - using EpilogueOp = cutlass::epilogue::thread:: \ - BiasAddLinearCombinationClamp< \ - ElementOutput, out_elements_per_access_, \ - ElementAccumulator, ElementBias, \ - ElementCompute>; \ - typename EpilogueOp::Params epilogue{alpha, beta, gamma}; \ - RUN_CUTLASS_WRAPPER(stage_, access_size_, \ - without_shared_load_); \ - } \ - case NonlineMode::RELU: { \ - using EpilogueOp = cutlass::epilogue::thread:: \ - BiasAddLinearCombinationReluClamp< \ - ElementOutput, out_elements_per_access_, \ - ElementAccumulator, ElementBias, \ - ElementCompute>; \ - typename EpilogueOp::Params epilogue{alpha, beta, gamma, 0}; \ - RUN_CUTLASS_WRAPPER(stage_, access_size_, \ - without_shared_load_); \ - } \ - case NonlineMode::H_SWISH: { \ - using EpilogueOp = cutlass::epilogue::thread:: \ - BiasAddLinearCombinationHSwishClamp< \ - ElementOutput, out_elements_per_access_, \ - ElementAccumulator, ElementBias, \ - ElementCompute>; \ - typename EpilogueOp::Params epilogue{alpha, beta, gamma, \ - scale}; \ - RUN_CUTLASS_WRAPPER(stage_, access_size_, \ - without_shared_load_); \ - } \ - default: \ - megdnn_assert( \ - false, \ - "unsupported nonlinear mode for conv bias operator"); \ - } \ - } -#define DISPATCH_KERNEL \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 32, 8, false); \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 16, 8, false); \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 8, 8, false); \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 32, 8, false); \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 16, 8, false); \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 8, 8, false); \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 32, 8, true); \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 16, 8, true); \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 8, 8, true); \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 32, 16, true); \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 16, 16, true); \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 8, 16, true); \ - megdnn_assert(false, \ - "unsupported threadblock shape (%dx%dx%d) and warp shape " \ - "(%dx%dx%d) and access_size (%d)", \ - threadblock_shape.m(), threadblock_shape.n(), \ - threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ - warp_shape.k(), access_size); - DISPATCH_KERNEL; - -#undef RUN_CUTLASS_WRAPPER -#undef DISPATCH_KERNEL_WITH_TILE_SHAPE -#undef DISPATCH_KERNEL -} -#endif - -#define INST(need_load_from_const_mem) \ - template void megdnn::cuda::cutlass_wrapper:: \ - do_conv_bias_int4_int4_implicit_gemm_imma_nhwc< \ - need_load_from_const_mem>( \ - const int8_t* d_src, const int8_t* d_filter, \ - const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, \ - int* workspace, const convolution::ConvParam& param, \ - uint32_t nonlinear_mode, float alpha, float beta, \ - float gamma, float scale, \ - const GemmCoord& threadblock_shape, \ - const GemmCoord& warp_shape, const int32_t access_size, \ - int stages, cudaStream_t stream); -INST(true); -INST(false); -#undef INST - 
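
Every wrapper deleted in this file follows the shape just seen for the int4 x int4 NHWC kernels: a DISPATCH_KERNEL macro chain compares the requested threadblock/warp/stage/access-size tuple against a hard-coded list, instantiates one cutlass::conv::device::Convolution per match, and the INST() macros pin the explicit template instantiations. The operation-table design this patch introduces replaces that compile-time enumeration with a runtime lookup keyed by the same parameters. The sketch below illustrates only the lookup idea; the real key type in dnn/src/cuda/cutlass/operation_table.h carries far more fields (dtypes, layouts, epilogue kind, and so on), and TileKey/find_conv_op are illustrative names, not patch code.

// Illustrative sketch of a runtime operation table (reduced field set).
#include <map>
#include <tuple>

struct TileKey {
    int threadblock_m, threadblock_n, threadblock_k;
    int warp_m, warp_n, warp_k;
    int stages;
    int access_size;

    bool operator<(const TileKey& rhs) const {
        return std::tie(threadblock_m, threadblock_n, threadblock_k, warp_m,
                        warp_n, warp_k, stages, access_size) <
               std::tie(rhs.threadblock_m, rhs.threadblock_n,
                        rhs.threadblock_k, rhs.warp_m, rhs.warp_n, rhs.warp_k,
                        rhs.stages, rhs.access_size);
    }
};

class Operation;  // stand-in for the library's operation base class

// Generated kernels register themselves once at startup (in the patch this
// happens via initialize_all.cu and the manifest), instead of being
// re-enumerated by DISPATCH_KERNEL at every call site.
std::map<TileKey, const Operation*>& table() {
    static std::map<TileKey, const Operation*> inst;
    return inst;
}

const Operation* find_conv_op(const TileKey& key) {
    auto iter = table().find(key);
    // An unsupported tile shape now surfaces as a failed lookup, rather
    // than falling off the end of a macro chain into megdnn_assert.
    return iter == table().end() ? nullptr : iter->second;
}

With this split, adding a tile configuration only touches the generator lists on the Python side; no dispatch macro or instantiation list grows at the call sites.
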
-/* ====== cutlass kernel wrapper for uint4 x int4 nhwc layout ====== */ - -#if MEGDNN_TEGRA_X1 -template -void megdnn::cuda::cutlass_wrapper:: - do_conv_bias_uint4_int4_implicit_gemm_imma_nhwc( - const uint8_t* /* d_src */, const int8_t* /* d_filter */, - const int32_t* /* d_bias */, const uint8_t* /* d_z */, - uint8_t* /* d_dst */, int* /* workspace */, - const convolution::ConvParam& /* param */, - uint32_t /* nonlinear_mode */, float /* alpha */, - float /* beta */, float /* gamma */, float /* delta */, - float /* theta */, float /* scale */, - uint8_t /* src_zero_point */, - const GemmCoord& /* threadblock_shape */, - const GemmCoord& /* warp_shape */, - const int32_t /* access_size */, int /* stages */, - cudaStream_t /* stream */) {} -#else -template -void megdnn::cuda::cutlass_wrapper:: - do_conv_bias_uint4_int4_implicit_gemm_imma_nhwc( - const uint8_t* d_src, const int8_t* d_filter, - const int32_t* d_bias, const uint8_t* d_z, uint8_t* d_dst, - int* workspace, const convolution::ConvParam& param, - uint32_t nonlinear_mode, float alpha, float beta, float gamma, - float delta, float theta, float /* scale */, - uint8_t src_zero_point, const GemmCoord& threadblock_shape, - const GemmCoord& warp_shape, const int32_t access_size, - int stages, cudaStream_t stream) { - bool without_shared_load = - ((param.co % threadblock_shape.n() == 0) && - (threadblock_shape.n() == 32 || threadblock_shape.n() == 64)); - int out_elements_per_access = - without_shared_load ? threadblock_shape.n() / 4 : 8; - -#define RUN_CUTLASS_WRAPPER(stage_, access_size_, without_shared_load_) \ - using Convolution = cutlass::conv::device::Convolution< \ - cutlass::uint4b_t, cutlass::layout::TensorNHWC, cutlass::int4b_t, \ - cutlass::layout::TensorNCxHWx, ElementOutput, \ - cutlass::layout::TensorNHWC, int32_t, cutlass::layout::TensorNHWC, \ - int32_t, cutlass::conv::ConvType::kConvolution, \ - cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \ - ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ - cutlass::conv::threadblock:: \ - ConvolutionFpropTransThreadblockSwizzle, \ - stage_, access_size_, access_size_, NeedLoadFromConstMem, \ - cutlass::arch::OpMultiplyAddSaturate, \ - cutlass::conv::ImplicitGemmMode::GEMM_TN, without_shared_load_>; \ - typename Convolution::ConvolutionParameter conv_param( \ - param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ - param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ - param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \ - return cutlass_convolution_wrapper( \ - reinterpret_cast(d_src), \ - reinterpret_cast(d_filter), d_bias, \ - reinterpret_cast(d_z), \ - reinterpret_cast(d_dst), workspace, \ - conv_param, epilogue, stream, {src_zero_point}); - -#define DISPATCH_KERNEL_WITH_TILE_SHAPE( \ - threadblock_m_, threadblock_n_, threadblock_k_, warp_m_, warp_n_, \ - warp_k_, stage_, access_size_, out_elements_per_access_, \ - without_shared_load_) \ - if (threadblock_shape.m() == threadblock_m_ && \ - threadblock_shape.n() == threadblock_n_ && \ - threadblock_shape.k() == threadblock_k_ && \ - warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ - warp_shape.k() == warp_k_ && stages == stage_ && \ - access_size == access_size_ && \ - out_elements_per_access == out_elements_per_access_ && \ - without_shared_load == without_shared_load_) { \ - using ThreadBlockShape = \ - cutlass::gemm::GemmShape; \ - using WarpShape = cutlass::gemm::GemmShape; \ - using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; \ - using ElementOutput = 
cutlass::uint4b_t; \ - using ElementAccumulator = int32_t; \ - using ElementBias = int32_t; \ - using ElementCompute = float; \ - using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode; \ - switch (nonlinear_mode) { \ - case NonlineMode::IDENTITY: { \ - using EpilogueOp = cutlass::epilogue::thread:: \ - BiasAddLinearCombinationClamp< \ - ElementOutput, out_elements_per_access_, \ - ElementAccumulator, ElementBias, \ - ElementCompute>; \ - typename EpilogueOp::Params epilogue{alpha, beta, gamma, \ - delta + theta}; \ - RUN_CUTLASS_WRAPPER(stage_, access_size_, \ - without_shared_load_); \ - } \ - case NonlineMode::RELU: { \ - using EpilogueOp = cutlass::epilogue::thread:: \ - BiasAddLinearCombinationReluClamp< \ - ElementOutput, out_elements_per_access_, \ - ElementAccumulator, ElementBias, \ - ElementCompute>; \ - typename EpilogueOp::Params epilogue{alpha, beta, gamma, \ - 0, delta, theta}; \ - RUN_CUTLASS_WRAPPER(stage_, access_size_, \ - without_shared_load_); \ - } \ - default: \ - megdnn_assert( \ - false, \ - "unsupported nonlinear mode for conv bias operator"); \ - } \ - } -#define DISPATCH_KERNEL \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 32, 8, false); \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 16, 8, false); \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 8, 8, false); \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 32, 8, false); \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 16, 8, false); \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 8, 8, false); \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 32, 8, true); \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 16, 8, true); \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 8, 8, true); \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 32, 16, true); \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 16, 16, true); \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 8, 16, true); \ - megdnn_assert(false, \ - "unsupported threadblock shape (%dx%dx%d) and warp shape " \ - "(%dx%dx%d) and access_size (%d)", \ - threadblock_shape.m(), threadblock_shape.n(), \ - threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ - warp_shape.k(), access_size); - - DISPATCH_KERNEL; - -#undef RUN_CUTLASS_WRAPPER -#undef DISPATCH_KERNEL_WITH_TILE_SHAPE -#undef DISPATCH_KERNEL -} -#endif - -#define INST(need_load_from_const_mem) \ - template void megdnn::cuda::cutlass_wrapper:: \ - do_conv_bias_uint4_int4_implicit_gemm_imma_nhwc< \ - need_load_from_const_mem>( \ - const uint8_t* d_src, const int8_t* d_filter, \ - const int32_t* d_bias, const uint8_t* d_z, uint8_t* d_dst, \ - int* workspace, const convolution::ConvParam& param, \ - uint32_t nonlinear_mode, float alpha, float beta, \ - float gamma, float delta, float theta, float scale, \ - uint8_t src_zero_point, \ - const GemmCoord& threadblock_shape, \ - const GemmCoord& warp_shape, const int32_t access_size, \ - int stages, cudaStream_t stream); -INST(true); -INST(false); -#undef INST - -// vim: syntax=cuda.doxygen diff --git a/dnn/src/cuda/conv_bias/cutlass_convolution_wrapper_int8.cu b/dnn/src/cuda/conv_bias/cutlass_convolution_wrapper_int8.cu deleted file mode 100644 index 208e080b..00000000 --- a/dnn/src/cuda/conv_bias/cutlass_convolution_wrapper_int8.cu +++ /dev/null @@ -1,804 +0,0 @@ -/** - * \file dnn/src/cuda/conv_bias/cutlass_convolution_wrapper.cu - * MegEngine is Licensed under the Apache 
License, Version 2.0 (the "License") - * - * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or - * implied. - */ -// ignore warning of cutlass -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wunused-parameter" -#pragma GCC diagnostic ignored "-Wstrict-aliasing" - -#if !MEGDNN_TEGRA_X1 -#include "cutlass/convolution/device/convolution.h" -#endif -#include "src/common/opr_param_defs_enumv.cuh" -#include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh" -#pragma GCC diagnostic pop - -using namespace megdnn; -using namespace cuda; -using namespace cutlass_wrapper; - -/* ====== cutlass kernel wrapper for int8 nchw32 layout ====== */ - -#if MEGDNN_TEGRA_X1 -template -void megdnn::cuda::cutlass_wrapper:: - do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32( - const int8_t* /* d_src */, const int8_t* /* d_filter */, - const int32_t* /* d_bias */, const int8_t* /* d_z */, - int8_t* /* d_dst */, int* /* workspace */, - const convolution::ConvParam& /* param */, - uint32_t /* nonlinear_mode */, float /* alpha */, - float /* beta */, float /* gamma */, float /* scale */, - const GemmCoord& /* threadblock_shape */, - const GemmCoord& /* warp_shape */, int /* stages */, - cudaStream_t /* stream */) {} -#else -template -void megdnn::cuda::cutlass_wrapper:: - do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32( - const int8_t* d_src, const int8_t* d_filter, - const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, - int* workspace, const convolution::ConvParam& param, - uint32_t nonlinear_mode, float alpha, float beta, float gamma, - float scale, const GemmCoord& threadblock_shape, - const GemmCoord& warp_shape, int stages, cudaStream_t stream) { -#define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \ - threadblock_k_, warp_m_, warp_n_, \ - warp_k_, stage_) \ - if (threadblock_shape.m() == threadblock_m_ && \ - threadblock_shape.n() == threadblock_n_ && \ - threadblock_shape.k() == threadblock_k_ && \ - warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ - warp_shape.k() == warp_k_ && stages == stage_) { \ - using ThreadBlockShape = \ - cutlass::gemm::GemmShape; \ - using WarpShape = cutlass::gemm::GemmShape; \ - using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; \ - using Convolution = cutlass::conv::device::Convolution< \ - int8_t, cutlass::layout::TensorNCxHWx<32>, int8_t, \ - cutlass::layout::TensorCxRSKx<32>, ElementOutput, \ - cutlass::layout::TensorNCxHWx<32>, int32_t, \ - cutlass::layout::TensorNCxHWx<32>, int32_t, \ - cutlass::conv::ConvType::kConvolution, \ - cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \ - ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ - cutlass::conv::threadblock:: \ - ConvolutionFpropTransThreadblockSwizzle, \ - stage_, 16, 16, NeedLoadFromConstMem, \ - cutlass::arch::OpMultiplyAddSaturate, \ - cutlass::conv::ImplicitGemmMode::GEMM_TN, true>; \ - typename Convolution::ConvolutionParameter conv_param( \ - param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ - param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ - param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \ - return cutlass_convolution_wrapper( \ - d_src, d_filter, d_bias, d_z, d_dst, workspace, conv_param, \ - epilogue, stream); \ - } -#define DISPATCH_KERNEL \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 256, 64, 
64, 64, 64, 2); \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(256, 128, 64, 64, 64, 64, 2); \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 64, 64, 64, 64, 2); \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 32, 64, 2); \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 64, 32, 64, 64, 2); \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 32, 64, 32, 32, 1); \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 32, 64, 32, 32, 1); \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 32, 64, 32, 1); \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 1); \ - megdnn_assert(false, \ - "unsupported threadblock shape (%dx%dx%d) and warp shape " \ - "(%dx%dx%d)", \ - threadblock_shape.m(), threadblock_shape.n(), \ - threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ - warp_shape.k()); - using ElementOutput = int8_t; - using ElementAccumulator = int32_t; - using ElementBias = int32_t; - using ElementCompute = float; - using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode; - switch (nonlinear_mode) { - case NonlineMode::IDENTITY: { - using EpilogueOp = - cutlass::epilogue::thread::BiasAddLinearCombinationClamp< - ElementOutput, 8, ElementAccumulator, ElementBias, - ElementCompute>; - typename EpilogueOp::Params epilogue{alpha, beta, gamma}; - DISPATCH_KERNEL; - } - case NonlineMode::RELU: { - using EpilogueOp = cutlass::epilogue::thread:: - BiasAddLinearCombinationReluClamp< - ElementOutput, 8, ElementAccumulator, ElementBias, - ElementCompute>; - typename EpilogueOp::Params epilogue{alpha, beta, gamma, 0}; - DISPATCH_KERNEL; - } - case NonlineMode::H_SWISH: { - using EpilogueOp = cutlass::epilogue::thread:: - BiasAddLinearCombinationHSwishClamp< - ElementOutput, 8, ElementAccumulator, ElementBias, - ElementCompute>; - typename EpilogueOp::Params epilogue{alpha, beta, gamma, scale}; - DISPATCH_KERNEL; - } - default: - megdnn_assert(false, - "unsupported nonlinear mode for conv bias operator"); - } -#undef DISPATCH_KERNEL_WITH_TILE_SHAPE -#undef DISPATCH_KERNEL -} -#endif - -#define INST(need_load_from_const_mem) \ - template void megdnn::cuda::cutlass_wrapper:: \ - do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32< \ - need_load_from_const_mem>( \ - const int8_t* d_src, const int8_t* d_filter, \ - const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, \ - int* workspace, const convolution::ConvParam& param, \ - uint32_t nonlinear_mode, float alpha, float beta, \ - float gamma, float scale, \ - const GemmCoord& threadblock_shape, \ - const GemmCoord& warp_shape, int stages, \ - cudaStream_t stream); -INST(true); -INST(false); -#undef INST - -/* ===== cutlass kernel wrapper for int8 nchw32 layout and nchw4 output ===== */ - -#if MEGDNN_TEGRA_X1 -template -void megdnn::cuda::cutlass_wrapper:: - do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4( - const int8_t* /* d_src */, const int8_t* /* d_filter */, - const int32_t* /* d_bias */, const int8_t* /* d_z */, - int8_t* /* d_dst */, int* /* workspace */, - const convolution::ConvParam& /* param */, - uint32_t /* nonlinear_mode */, float /* alpha */, - float /* beta */, float /* gamma */, float /* scale */, - const GemmCoord& /* threadblock_shape */, - const GemmCoord& /* warp_shape */, int /* stages */, - cudaStream_t /* stream */) {} -#else -template -void megdnn::cuda::cutlass_wrapper:: - do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4( - const int8_t* d_src, const int8_t* d_filter, - const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, - int* workspace, const convolution::ConvParam& param, - uint32_t nonlinear_mode, float alpha, 
float beta, float gamma, - float scale, const GemmCoord& threadblock_shape, - const GemmCoord& warp_shape, int stages, cudaStream_t stream) { -#define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \ - threadblock_k_, warp_m_, warp_n_, \ - warp_k_, stage_) \ - if (threadblock_shape.m() == threadblock_m_ && \ - threadblock_shape.n() == threadblock_n_ && \ - threadblock_shape.k() == threadblock_k_ && \ - warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ - warp_shape.k() == warp_k_ && stages == stage_) { \ - using ThreadBlockShape = \ - cutlass::gemm::GemmShape; \ - using WarpShape = cutlass::gemm::GemmShape; \ - using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; \ - using Convolution = cutlass::conv::device::Convolution< \ - int8_t, cutlass::layout::TensorNCxHWx<32>, int8_t, \ - cutlass::layout::TensorCxRSKx<32>, ElementOutput, \ - cutlass::layout::TensorNCxHWx<4>, int32_t, \ - cutlass::layout::TensorNCxHWx<4>, int32_t, \ - cutlass::conv::ConvType::kConvolution, \ - cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \ - ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ - cutlass::conv::threadblock:: \ - ConvolutionFpropNCxHWxThreadblockSwizzle, \ - stage_, 16, 16, NeedLoadFromConstMem>; \ - typename Convolution::ConvolutionParameter conv_param( \ - param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ - param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ - param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \ - return cutlass_convolution_wrapper( \ - d_src, d_filter, d_bias, d_z, d_dst, workspace, conv_param, \ - epilogue, stream); \ - } -#define DISPATCH_KERNEL \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 256, 64, 64, 64, 64, 2); \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(256, 128, 64, 64, 64, 64, 2); \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 64, 64, 64, 64, 2); \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 32, 64, 2); \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 64, 32, 64, 64, 2); \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 32, 64, 32, 32, 1); \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 32, 64, 32, 32, 1); \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 32, 64, 32, 1); \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 1); \ - megdnn_assert(false, \ - "unsupported threadblock shape (%dx%dx%d) and warp shape " \ - "(%dx%dx%d)", \ - threadblock_shape.m(), threadblock_shape.n(), \ - threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ - warp_shape.k()); - using ElementOutput = int8_t; - using ElementAccumulator = int32_t; - using ElementBias = int32_t; - using ElementCompute = float; - using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode; - switch (nonlinear_mode) { - case NonlineMode::IDENTITY: { - using EpilogueOp = - cutlass::epilogue::thread::BiasAddLinearCombinationClamp< - ElementOutput, 4, ElementAccumulator, ElementBias, - ElementCompute>; - typename EpilogueOp::Params epilogue{alpha, beta, gamma}; - DISPATCH_KERNEL; - } - case NonlineMode::RELU: { - using EpilogueOp = cutlass::epilogue::thread:: - BiasAddLinearCombinationReluClamp< - ElementOutput, 4, ElementAccumulator, ElementBias, - ElementCompute>; - typename EpilogueOp::Params epilogue{alpha, beta, gamma, 0}; - DISPATCH_KERNEL; - } - case NonlineMode::H_SWISH: { - using EpilogueOp = cutlass::epilogue::thread:: - BiasAddLinearCombinationHSwishClamp< - ElementOutput, 4, ElementAccumulator, ElementBias, - ElementCompute>; - typename EpilogueOp::Params epilogue{alpha, beta, gamma, scale}; - DISPATCH_KERNEL; - } - default: - 
megdnn_assert(false, - "unsupported nonlinear mode for conv bias operator"); - } -#undef DISPATCH_KERNEL_WITH_TILE_SHAPE -#undef DISPATCH_KERNEL -} -#endif - -#define INST(need_load_from_const_mem) \ - template void megdnn::cuda::cutlass_wrapper:: \ - do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4< \ - need_load_from_const_mem>( \ - const int8_t* d_src, const int8_t* d_filter, \ - const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, \ - int* workspace, const convolution::ConvParam& param, \ - uint32_t nonlinear_mode, float alpha, float beta, \ - float gamma, float scale, \ - const GemmCoord& threadblock_shape, \ - const GemmCoord& warp_shape, int stages, \ - cudaStream_t stream); -INST(true); -INST(false); -#undef INST - -/* ====== cutlass kernel wrapper for int8 nchw4 layout ====== */ - -#if MEGDNN_TEGRA_X1 -template -void megdnn::cuda::cutlass_wrapper:: - do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4( - const int8_t* /* d_src */, const int8_t* /* d_filter */, - const int32_t* /* d_bias */, const int8_t* /* d_z */, - int8_t* /* d_dst */, int* /* workspace */, - const convolution::ConvParam& /* param */, - uint32_t /* nonlinear_mode */, float /* alpha */, - float /* beta */, float /* gamma */, float /* scale */, - const GemmCoord& /* threadblock_shape */, - const GemmCoord& /* warp_shape */, int /* stages */, - cudaStream_t /* stream */) {} -#else -template -void megdnn::cuda::cutlass_wrapper:: - do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4( - const int8_t* d_src, const int8_t* d_filter, - const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, - int* workspace, const convolution::ConvParam& param, - uint32_t nonlinear_mode, float alpha, float beta, float gamma, - float scale, const GemmCoord& threadblock_shape, - const GemmCoord& warp_shape, int stages, cudaStream_t stream) { -#define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \ - threadblock_k_, warp_m_, warp_n_, \ - warp_k_, stage_, aligned_) \ - if (threadblock_shape.m() == threadblock_m_ && \ - threadblock_shape.n() == threadblock_n_ && \ - threadblock_shape.k() == threadblock_k_ && \ - warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ - warp_shape.k() == warp_k_ && stages == stage_) { \ - using ThreadBlockShape = \ - cutlass::gemm::GemmShape; \ - using WarpShape = cutlass::gemm::GemmShape; \ - using InstructionShape = cutlass::gemm::GemmShape<1, 1, 4>; \ - using Convolution = cutlass::conv::device::Convolution< \ - int8_t, cutlass::layout::TensorNCxHWx<4>, int8_t, \ - cutlass::layout::TensorCxRSKx<4>, ElementOutput, \ - cutlass::layout::TensorNCxHWx<4>, int32_t, \ - cutlass::layout::TensorNCxHWx<4>, int32_t, \ - cutlass::conv::ConvType::kConvolution, \ - cutlass::arch::OpClassSimt, cutlass::arch::Sm61, \ - ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ - cutlass::conv::threadblock:: \ - ConvolutionFpropNCxHWxThreadblockSwizzle, \ - stage_, 4, aligned_, NeedLoadFromConstMem, \ - cutlass::arch::OpMultiplyAdd>; \ - typename Convolution::ConvolutionParameter conv_param( \ - param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ - param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ - param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \ - return cutlass_convolution_wrapper( \ - d_src, d_filter, d_bias, d_z, d_dst, workspace, conv_param, \ - epilogue, stream); \ - } -#define DISPATCH_KERNEL \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 32, 64, 32, 32, 2, 16); \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 32, 64, 32, 32, 2, 16); \ - 
DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 64, 32, 32, 2, 16); \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 32, 64, 32, 32, 2, 16); \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 2, 16); \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 32, 32, 64, 32, 2, 16); \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 32, 32, 64, 32, 32, 2, 16); \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 128, 16, 16, 128, 16, 1, 8); \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 64, 8, 16, 64, 8, 2, 4); \ - megdnn_assert(false, \ - "unsupported threadblock shape (%dx%dx%d) and warp shape " \ - "(%dx%dx%d)", \ - threadblock_shape.m(), threadblock_shape.n(), \ - threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ - warp_shape.k()); - using ElementOutput = int8_t; - using ElementAccumulator = int32_t; - using ElementBias = int32_t; - using ElementCompute = float; - using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode; - switch (nonlinear_mode) { - case NonlineMode::IDENTITY: { - using EpilogueOp = - cutlass::epilogue::thread::BiasAddLinearCombinationClamp< - ElementOutput, 4, ElementAccumulator, ElementBias, - ElementCompute>; - typename EpilogueOp::Params epilogue{alpha, beta, gamma}; - DISPATCH_KERNEL; - } - case NonlineMode::RELU: { - using EpilogueOp = cutlass::epilogue::thread:: - BiasAddLinearCombinationReluClamp< - ElementOutput, 4, ElementAccumulator, ElementBias, - ElementCompute>; - typename EpilogueOp::Params epilogue{alpha, beta, gamma, 0}; - DISPATCH_KERNEL; - } - case NonlineMode::H_SWISH: { - using EpilogueOp = cutlass::epilogue::thread:: - BiasAddLinearCombinationHSwishClamp< - ElementOutput, 4, ElementAccumulator, ElementBias, - ElementCompute>; - typename EpilogueOp::Params epilogue{alpha, beta, gamma, scale}; - DISPATCH_KERNEL; - } - default: - megdnn_assert(false, - "unsupported nonlinear mode for conv bias operator"); - } -#undef DISPATCH_KERNEL_WITH_TILE_SHAPE -#undef DISPATCH_KERNEL -} -#endif - -#define INST(need_load_from_const_mem) \ - template void megdnn::cuda::cutlass_wrapper:: \ - do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4< \ - need_load_from_const_mem>( \ - const int8_t* d_src, const int8_t* d_filter, \ - const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, \ - int* workspace, const convolution::ConvParam& param, \ - uint32_t nonlinear_mode, float alpha, float beta, \ - float gamma, float scale, \ - const GemmCoord& threadblock_shape, \ - const GemmCoord& warp_shape, int stages, \ - cudaStream_t stream); -INST(true); -INST(false); -#undef INST - -/* ====== cutlass kernel wrapper for int8 nchw4 layout and nchw output ====== */ - -#if MEGDNN_TEGRA_X1 -template -void megdnn::cuda::cutlass_wrapper:: - do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nchw( - const int8_t* /* d_src */, const int8_t* /* d_filter */, - const float* /* d_bias */, const float* /* d_z */, - float* /* d_dst */, int* /* workspace */, - const convolution::ConvParam& /* param */, - uint32_t /* nonlinear_mode */, float /* alpha */, - float /* beta */, float /* gamma */, float /* scale */, - const GemmCoord& /* threadblock_shape */, - const GemmCoord& /* warp_shape */, int /* stages */, - cudaStream_t /* stream */) {} -#else -template -void megdnn::cuda::cutlass_wrapper:: - do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nchw( - const int8_t* d_src, const int8_t* d_filter, - const float* d_bias, const float* d_z, float* d_dst, - int* workspace, const convolution::ConvParam& param, - uint32_t nonlinear_mode, float alpha, float beta, float gamma, - float scale, const GemmCoord& threadblock_shape, - const 
GemmCoord& warp_shape, int stages, cudaStream_t stream) { -#define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \ - threadblock_k_, warp_m_, warp_n_, \ - warp_k_, stages_, aligned_) \ - if (threadblock_shape.m() == threadblock_m_ && \ - threadblock_shape.n() == threadblock_n_ && \ - threadblock_shape.k() == threadblock_k_ && \ - warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ - warp_shape.k() == warp_k_ && stages == stages_) { \ - using ThreadBlockShape = \ - cutlass::gemm::GemmShape; \ - using WarpShape = cutlass::gemm::GemmShape; \ - using InstructionShape = cutlass::gemm::GemmShape<1, 1, 4>; \ - using Convolution = cutlass::conv::device::Convolution< \ - int8_t, cutlass::layout::TensorNCxHWx<4>, int8_t, \ - cutlass::layout::TensorCxRSKx<4>, ElementOutput, \ - cutlass::layout::TensorNCHW, float, \ - cutlass::layout::TensorNCHW, int32_t, \ - cutlass::conv::ConvType::kConvolution, \ - cutlass::arch::OpClassSimt, cutlass::arch::Sm61, \ - ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ - cutlass::conv::threadblock:: \ - ConvolutionFpropNCxHWxThreadblockSwizzle, \ - stages_, 4, aligned_, NeedLoadFromConstMem, \ - cutlass::arch::OpMultiplyAdd>; \ - typename Convolution::ConvolutionParameter conv_param( \ - param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ - param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ - param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \ - return cutlass_convolution_wrapper( \ - d_src, d_filter, d_bias, d_z, d_dst, workspace, conv_param, \ - epilogue, stream); \ - } -#define DISPATCH_KERNEL \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 32, 64, 32, 32, 2, 16); \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 32, 64, 32, 32, 2, 16); \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 64, 32, 32, 2, 16); \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 32, 64, 32, 32, 2, 16); \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 2, 16); \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 32, 32, 64, 32, 2, 16); \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 32, 32, 64, 32, 32, 2, 16); \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 128, 16, 16, 128, 16, 1, 8); \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 64, 8, 16, 64, 8, 2, 4); \ - megdnn_assert(false, \ - "unsupported threadblock shape (%dx%dx%d) and warp shape " \ - "(%dx%dx%d)", \ - threadblock_shape.m(), threadblock_shape.n(), \ - threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ - warp_shape.k()); - using ElementOutput = float; - using ElementAccumulator = int32_t; - using ElementBias = float; - using ElementCompute = float; - using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode; - switch (nonlinear_mode) { - case NonlineMode::IDENTITY: { - using EpilogueOp = - cutlass::epilogue::thread::BiasAddLinearCombination< - ElementOutput, 1, ElementAccumulator, ElementBias, - ElementCompute>; - typename EpilogueOp::Params epilogue{alpha, beta, gamma}; - DISPATCH_KERNEL; - } - case NonlineMode::RELU: { - using EpilogueOp = - cutlass::epilogue::thread::BiasAddLinearCombinationRelu< - ElementOutput, 1, ElementAccumulator, ElementBias, - ElementCompute>; - typename EpilogueOp::Params epilogue{alpha, beta, gamma, 0}; - DISPATCH_KERNEL; - } - case NonlineMode::H_SWISH: { - using EpilogueOp = - cutlass::epilogue::thread::BiasAddLinearCombinationHSwish< - ElementOutput, 1, ElementAccumulator, ElementBias, - ElementCompute>; - typename EpilogueOp::Params epilogue{alpha, beta, gamma, scale}; - DISPATCH_KERNEL; - } - default: - megdnn_assert(false, - "unsupported nonlinear mode 
for conv bias operator"); - } -#undef DISPATCH_KERNEL_WITH_TILE_SHAPE -#undef DISPATCH_KERNEL -} -#endif - -#define INST(need_load_from_const_mem) \ - template void megdnn::cuda::cutlass_wrapper:: \ - do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nchw< \ - need_load_from_const_mem>( \ - const int8_t* d_src, const int8_t* d_filter, \ - const float* d_bias, const float* d_z, float* d_dst, \ - int* workspace, const convolution::ConvParam& param, \ - uint32_t nonlinear_mode, float alpha, float beta, \ - float gamma, float scale, \ - const GemmCoord& threadblock_shape, \ - const GemmCoord& warp_shape, int stages, \ - cudaStream_t stream); -INST(true); -INST(false); -#undef INST - -/* ===== cutlass kernel wrapper for int8 nchw4 layout and nchw32 output ===== */ - -#if MEGDNN_TEGRA_X1 -template -void megdnn::cuda::cutlass_wrapper:: - do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32( - const int8_t* /* d_src */, const int8_t* /* d_filter */, - const int32_t* /* d_bias */, const int8_t* /* d_z */, - int8_t* /* d_dst */, int* /* workspace */, - const convolution::ConvParam& /* param */, - uint32_t /* nonlinear_mode */, float /* alpha */, - float /* beta */, float /* gamma */, float /* scale */, - const GemmCoord& /* threadblock_shape */, - const GemmCoord& /* warp_shape */, int /* stages */, - cudaStream_t /* stream */) {} -#else -template -void megdnn::cuda::cutlass_wrapper:: - do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32( - const int8_t* d_src, const int8_t* d_filter, - const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, - int* workspace, const convolution::ConvParam& param, - uint32_t nonlinear_mode, float alpha, float beta, float gamma, - float scale, const GemmCoord& threadblock_shape, - const GemmCoord& warp_shape, int stages, cudaStream_t stream) { -#define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \ - threadblock_k_, warp_m_, warp_n_, \ - warp_k_, stages_, aligned_) \ - if (threadblock_shape.m() == threadblock_m_ && \ - threadblock_shape.n() == threadblock_n_ && \ - threadblock_shape.k() == threadblock_k_ && \ - warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ - warp_shape.k() == warp_k_ && stages == stages_) { \ - using ThreadBlockShape = \ - cutlass::gemm::GemmShape; \ - using WarpShape = cutlass::gemm::GemmShape; \ - using InstructionShape = cutlass::gemm::GemmShape<1, 1, 4>; \ - using Convolution = cutlass::conv::device::Convolution< \ - int8_t, cutlass::layout::TensorNCxHWx<4>, int8_t, \ - cutlass::layout::TensorCxRSKx<4>, ElementOutput, \ - cutlass::layout::TensorNCxHWx<32>, int32_t, \ - cutlass::layout::TensorNCxHWx<32>, int32_t, \ - cutlass::conv::ConvType::kConvolution, \ - cutlass::arch::OpClassSimt, cutlass::arch::Sm61, \ - ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ - cutlass::conv::threadblock:: \ - ConvolutionFpropNCxHWxThreadblockSwizzle, \ - stages_, 4, aligned_, NeedLoadFromConstMem, \ - cutlass::arch::OpMultiplyAdd>; \ - typename Convolution::ConvolutionParameter conv_param( \ - param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ - param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ - param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \ - return cutlass_convolution_wrapper( \ - d_src, d_filter, d_bias, d_z, d_dst, workspace, conv_param, \ - epilogue, stream); \ - } -#define DISPATCH_KERNEL \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 32, 64, 32, 32, 2, 16); \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 32, 64, 32, 32, 2, 16); \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 
32, 64, 32, 32, 2, 16); \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 32, 64, 32, 32, 2, 16); \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 2, 16); \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 32, 32, 64, 32, 2, 16); \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 32, 32, 64, 32, 32, 2, 16); \ - megdnn_assert(false, \ - "unsupported threadblock shape (%dx%dx%d) and warp shape " \ - "(%dx%dx%d)", \ - threadblock_shape.m(), threadblock_shape.n(), \ - threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ - warp_shape.k()); - using ElementOutput = int8_t; - using ElementAccumulator = int32_t; - using ElementBias = int32_t; - using ElementCompute = float; - using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode; - switch (nonlinear_mode) { - case NonlineMode::IDENTITY: { - using EpilogueOp = - cutlass::epilogue::thread::BiasAddLinearCombinationClamp< - ElementOutput, 4, ElementAccumulator, ElementBias, - ElementCompute>; - typename EpilogueOp::Params epilogue{alpha, beta, gamma}; - DISPATCH_KERNEL; - } - case NonlineMode::RELU: { - using EpilogueOp = cutlass::epilogue::thread:: - BiasAddLinearCombinationReluClamp< - ElementOutput, 4, ElementAccumulator, ElementBias, - ElementCompute>; - typename EpilogueOp::Params epilogue{alpha, beta, gamma, 0}; - DISPATCH_KERNEL; - } - case NonlineMode::H_SWISH: { - using EpilogueOp = cutlass::epilogue::thread:: - BiasAddLinearCombinationHSwishClamp< - ElementOutput, 4, ElementAccumulator, ElementBias, - ElementCompute>; - typename EpilogueOp::Params epilogue{alpha, beta, gamma, scale}; - DISPATCH_KERNEL; - } - default: - megdnn_assert(false, - "unsupported nonlinear mode for conv bias operator"); - } -#undef DISPATCH_KERNEL_WITH_TILE_SHAPE -#undef DISPATCH_KERNEL -} -#endif - -#define INST(need_load_from_const_mem) \ - template void megdnn::cuda::cutlass_wrapper:: \ - do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32< \ - need_load_from_const_mem>( \ - const int8_t* d_src, const int8_t* d_filter, \ - const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, \ - int* workspace, const convolution::ConvParam& param, \ - uint32_t nonlinear_mode, float alpha, float beta, \ - float gamma, float scale, \ - const GemmCoord& threadblock_shape, \ - const GemmCoord& warp_shape, int stages, \ - cudaStream_t stream); -INST(true); -INST(false); -#undef INST - -/* ===== cutlass kernel wrapper for nchw4 layout and nhwc output ===== */ -#if MEGDNN_TEGRA_X1 -template -void megdnn::cuda::cutlass_wrapper:: - do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nhwc( - const int8_t* /* d_src */, const int8_t* /* d_filter */, - const int32_t* /* d_bias */, const int8_t* /* d_z */, - int8_t* /* d_dst */, int* /* workspace */, - const convolution::ConvParam& /* param */, - uint32_t /* nonlinear_mode */, float /* alpha */, - float /* beta */, float /* gamma */, float /* delta */, - float /* theta */, float /* scale */, - const GemmCoord& /* threadblock_shape */, - const GemmCoord& /* warp_shape */, int /* stages */, - cudaStream_t /* stream */) {} -#else -template -void megdnn::cuda::cutlass_wrapper:: - do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nhwc( - const int8_t* d_src, const int8_t* d_filter, - const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, - int* workspace, const convolution::ConvParam& param, - uint32_t nonlinear_mode, float alpha, float beta, float gamma, - float delta, float theta, float scale, - const GemmCoord& threadblock_shape, const GemmCoord& warp_shape, - int stages, cudaStream_t stream) { -#define 
DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \ - threadblock_k_, warp_m_, warp_n_, \ - warp_k_, stages_, aligned_) \ - if (threadblock_shape.m() == threadblock_m_ && \ - threadblock_shape.n() == threadblock_n_ && \ - threadblock_shape.k() == threadblock_k_ && \ - warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ - warp_shape.k() == warp_k_ && stages == stages_) { \ - using ThreadBlockShape = \ - cutlass::gemm::GemmShape; \ - using WarpShape = cutlass::gemm::GemmShape; \ - using InstructionShape = cutlass::gemm::GemmShape<1, 1, 4>; \ - using Convolution = cutlass::conv::device::Convolution< \ - int8_t, cutlass::layout::TensorNCxHWx<4>, int8_t, \ - cutlass::layout::TensorCxRSKx<4>, ElementOutput, \ - cutlass::layout::TensorNHWC, int32_t, \ - cutlass::layout::TensorNHWC, int32_t, \ - cutlass::conv::ConvType::kConvolution, \ - cutlass::arch::OpClassSimt, cutlass::arch::Sm75, \ - ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ - cutlass::conv::threadblock:: \ - ConvolutionFpropNCxHWxThreadblockSwizzle, \ - stages_, 4, aligned_, true, cutlass::arch::OpMultiplyAdd>; \ - typename Convolution::ConvolutionParameter conv_param( \ - param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ - param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ - param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \ - return cutlass_convolution_wrapper( \ - d_src, d_filter, d_bias, \ - reinterpret_cast(d_z), \ - reinterpret_cast(d_dst), workspace, \ - conv_param, epilogue, stream); \ - } -#define DISPATCH_KERNEL \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 32, 64, 32, 32, 2, 16); \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 32, 64, 32, 32, 2, 16); \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 64, 32, 32, 2, 16); \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 32, 64, 32, 32, 2, 16); \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 2, 16); \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 32, 32, 64, 32, 2, 16); \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 32, 32, 64, 32, 32, 2, 16); \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 128, 16, 16, 128, 16, 1, 8); \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 64, 8, 16, 64, 8, 2, 4); \ - megdnn_assert(false, \ - "unsupported threadblock shape (%dx%dx%d) and warp shape " \ - "(%dx%dx%d)", \ - threadblock_shape.m(), threadblock_shape.n(), \ - threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ - warp_shape.k()); - using ElementOutput = cutlass::integer_subbyte<4, signedness>; - using ElementAccumulator = int32_t; - using ElementBias = int32_t; - using ElementCompute = float; - using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode; - switch (nonlinear_mode) { - case NonlineMode::IDENTITY: { - using EpilogueOp = - cutlass::epilogue::thread::BiasAddLinearCombinationClamp< - ElementOutput, 8, ElementAccumulator, ElementBias, - ElementCompute>; - typename EpilogueOp::Params epilogue{alpha, beta, gamma, - delta + theta}; - DISPATCH_KERNEL; - } - case NonlineMode::RELU: { - using EpilogueOp = cutlass::epilogue::thread:: - BiasAddLinearCombinationReluClamp< - ElementOutput, 8, ElementAccumulator, ElementBias, - ElementCompute>; - typename EpilogueOp::Params epilogue{alpha, beta, gamma, - 0, delta, theta}; - DISPATCH_KERNEL; - } - case NonlineMode::H_SWISH: { - using EpilogueOp = cutlass::epilogue::thread:: - BiasAddLinearCombinationHSwishClamp< - ElementOutput, 8, ElementAccumulator, ElementBias, - ElementCompute>; - typename EpilogueOp::Params epilogue{alpha, beta, gamma, - scale, delta, theta}; - DISPATCH_KERNEL; - } - 
default: - megdnn_assert(false, - "unsupported nonlinear mode for conv bias operator"); - } -#undef DISPATCH_KERNEL_WITH_TILE_SHAPE -#undef DISPATCH_KERNEL -} -#endif - -#define INST(signedness) \ - template void megdnn::cuda::cutlass_wrapper:: \ - do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nhwc( \ - const int8_t* d_src, const int8_t* d_filter, \ - const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, \ - int* workspace, const convolution::ConvParam& param, \ - uint32_t nonlinear_mode, float alpha, float beta, \ - float gamma, float delta, float theta, float scale, \ - const GemmCoord& threadblock_shape, \ - const GemmCoord& warp_shape, int stages, \ - cudaStream_t stream); -INST(true); -INST(false); -#undef INST - -// vim: syntax=cuda.doxygen diff --git a/dnn/src/cuda/conv_bias/implicit_gemm_conv_bias_cutlass_wrapper.cuinl b/dnn/src/cuda/conv_bias/implicit_gemm_conv_bias_cutlass_wrapper.cuinl deleted file mode 100644 index 6d1582bd..00000000 --- a/dnn/src/cuda/conv_bias/implicit_gemm_conv_bias_cutlass_wrapper.cuinl +++ /dev/null @@ -1,65 +0,0 @@ -/** - * \file - * dnn/src/cuda/conv_bias/int8/implicit_gemm_conv_bias_cutlass_wrapper.cuinl - * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") - * - * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or - * implied. - */ -#include "cutlass/convolution/device/convolution.h" -#include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh" - -using namespace megdnn; -using namespace cuda; -using namespace cutlass_wrapper; - -template -void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper( - const typename Convolution::ElementSrc* d_src, - const typename Convolution::ElementFilter* d_filter, - const typename Convolution::ElementBias* d_bias, - const typename Convolution::ElementDst* d_z, - typename Convolution::ElementDst* d_dst, int* workspace, - typename Convolution::ConvolutionParameter const& conv_param, - typename Convolution::EpilogueOutputOp::Params const& epilogue, - cudaStream_t stream, typename Convolution::ExtraParam extra_param) { - typename Convolution::TensorRefSrc tensor_src{ - const_cast(d_src), - Convolution::LayoutSrc::packed( - {conv_param.N, conv_param.H, conv_param.W, conv_param.C})}; - typename Convolution::TensorRefFilter tensor_filter{ - const_cast(d_filter), - Convolution::LayoutFilter::packed( - {conv_param.K, conv_param.R, conv_param.S, conv_param.C})}; - typename Convolution::TensorRefBias tensor_bias{ - const_cast(d_bias), - Convolution::LayoutBias::packed({1, 1, 1, conv_param.K})}; - typename Convolution::TensorRefDst tensor_z{ - const_cast(d_z), - Convolution::LayoutDst::packed( - {conv_param.N, conv_param.P, conv_param.Q, conv_param.K})}; - typename Convolution::TensorRefDst tensor_dst{ - d_dst, - Convolution::LayoutDst::packed( - {conv_param.N, conv_param.P, conv_param.Q, conv_param.K})}; - typename Convolution::Arguments arguments{conv_param, - tensor_src.non_const_ref(), - tensor_filter.non_const_ref(), - tensor_bias.non_const_ref(), - tensor_z.non_const_ref(), - tensor_dst.non_const_ref(), - epilogue, - {}, - {}, - extra_param}; - Convolution conv_op; - cutlass_check(conv_op.initialize(arguments, workspace)); - cutlass_check(conv_op(stream)); - after_kernel_launch(); -} - -// vim: syntax=cuda.doxygen diff --git 
a/dnn/src/cuda/conv_bias/implicit_gemm_int4_int4_nchw64_imma.cpp b/dnn/src/cuda/conv_bias/implicit_gemm_int4_int4_nchw64_imma.cpp
index 47a65935..bcd8c655 100644
--- a/dnn/src/cuda/conv_bias/implicit_gemm_int4_int4_nchw64_imma.cpp
+++ b/dnn/src/cuda/conv_bias/implicit_gemm_int4_int4_nchw64_imma.cpp
@@ -10,8 +10,7 @@
  * implied.
  */
 
-#include "./algo.h"
-#include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh"
+#include "src/cuda/conv_bias/algo.h"
 
 using namespace megdnn;
 using namespace cuda;
@@ -81,29 +80,6 @@ ConvBiasForwardImpl::AlgoInt4Int4NCHW64IMMAImplicitGemm::get_constants(
     return {alpha, beta, gamma, delta, theta};
 }
-
-void ConvBiasForwardImpl::AlgoInt4Int4NCHW64IMMAImplicitGemm::do_exec(
-        const ExecArgs& args, void* filter_ptr, void* bias_ptr, void* z_ptr,
-        ConvParam kern_param, uint32_t nonlinear_mode, float alpha, float beta,
-        float gamma, float delta, float theta, cudaStream_t stream) const {
-    float dst_scale = args.dst_layout->dtype.param<dtype::QuantizedS4>().scale;
-
-    cutlass_wrapper::GemmCoord threadblock_shape{m_algo_param.threadblock_m,
-                                                 m_algo_param.threadblock_n,
-                                                 m_algo_param.threadblock_k};
-
-    cutlass_wrapper::GemmCoord warp_shape{
-            m_algo_param.warp_m, m_algo_param.warp_n, m_algo_param.warp_k};
-
-    cutlass_wrapper::do_conv_bias_int4_int4_implicit_gemm_imma_ncdiv64hw64<
-            true>(reinterpret_cast<int8_t*>(args.src_tensor->raw_ptr),
-                  reinterpret_cast<int8_t*>(filter_ptr),
-                  reinterpret_cast<int32_t*>(bias_ptr),
-                  reinterpret_cast<int8_t*>(z_ptr),
-                  reinterpret_cast<int8_t*>(args.dst_tensor->raw_ptr), nullptr,
-                  kern_param, nonlinear_mode, alpha, beta, gamma, dst_scale,
-                  threadblock_shape, warp_shape, m_algo_param.stage, stream);
-}
 #endif
 
 // vim: syntax=cpp.doxygen
diff --git a/dnn/src/cuda/conv_bias/implicit_gemm_int4_int4_nhwc_imma.cpp b/dnn/src/cuda/conv_bias/implicit_gemm_int4_int4_nhwc_imma.cpp
index 30797287..cbe3bc5d 100644
--- a/dnn/src/cuda/conv_bias/implicit_gemm_int4_int4_nhwc_imma.cpp
+++ b/dnn/src/cuda/conv_bias/implicit_gemm_int4_int4_nhwc_imma.cpp
@@ -10,8 +10,7 @@
  * implied.
*/ -#include "./algo.h" -#include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh" +#include "src/cuda/conv_bias/algo.h" using namespace megdnn; using namespace cuda; @@ -81,42 +80,6 @@ ConvBiasForwardImpl::AlgoInt4Int4NHWCIMMAImplicitGemm::get_constants( return {alpha, beta, gamma, delta, theta}; } - -void ConvBiasForwardImpl::AlgoInt4Int4NHWCIMMAImplicitGemm::do_exec( - const ExecArgs& args, void* filter_ptr, void* bias_ptr, void* z_ptr, - ConvParam kern_param, uint32_t nonlinear_mode, float alpha, float beta, - float gamma, float delta, float theta, cudaStream_t stream) const { - float dst_scale = args.dst_layout->dtype.param().scale; - - cutlass_wrapper::GemmCoord threadblock_shape{m_algo_param.threadblock_m, - m_algo_param.threadblock_n, - m_algo_param.threadblock_k}; - - cutlass_wrapper::GemmCoord warp_shape{ - m_algo_param.warp_m, m_algo_param.warp_n, m_algo_param.warp_k}; - - if (kern_param.fh == 1 && kern_param.fw == 1) { - cutlass_wrapper::do_conv_bias_int4_int4_implicit_gemm_imma_nhwc( - reinterpret_cast(args.src_tensor->raw_ptr), - reinterpret_cast(filter_ptr), - reinterpret_cast(bias_ptr), - reinterpret_cast(z_ptr), - reinterpret_cast(args.dst_tensor->raw_ptr), nullptr, - kern_param, nonlinear_mode, alpha, beta, gamma, dst_scale, - threadblock_shape, warp_shape, m_algo_param.access_size, - m_algo_param.stage, stream); - } else { - cutlass_wrapper::do_conv_bias_int4_int4_implicit_gemm_imma_nhwc( - reinterpret_cast(args.src_tensor->raw_ptr), - reinterpret_cast(filter_ptr), - reinterpret_cast(bias_ptr), - reinterpret_cast(z_ptr), - reinterpret_cast(args.dst_tensor->raw_ptr), nullptr, - kern_param, nonlinear_mode, alpha, beta, gamma, dst_scale, - threadblock_shape, warp_shape, m_algo_param.access_size, - m_algo_param.stage, stream); - } -} #endif // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/conv_bias/implicit_gemm_int4_nchw64_imma_base.cpp b/dnn/src/cuda/conv_bias/implicit_gemm_int4_nchw64_imma_base.cpp index eeb01c62..5f4f11ef 100644 --- a/dnn/src/cuda/conv_bias/implicit_gemm_int4_nchw64_imma_base.cpp +++ b/dnn/src/cuda/conv_bias/implicit_gemm_int4_nchw64_imma_base.cpp @@ -10,10 +10,9 @@ * implied. 
*/ -#include "./algo.h" #include "src/common/conv_bias.h" +#include "src/cuda/conv_bias/algo.h" #include "src/cuda/conv_bias/cutlass_reorder_filter.cuh" -#include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh" #include "src/cuda/conv_bias/reduce_filter.cuh" #include "src/cuda/convolution_helper/parameter.cuh" #include "src/cuda/utils.h" @@ -102,22 +101,40 @@ void ConvBiasForwardImpl::AlgoInt4NCHW64IMMAImplicitGemmBase::exec( if (args.z_layout->ndim > 0) z_ptr = args.z_tensor->raw_ptr; + // \note these constants of cutlass epilogue will be passed to method + // `execute_cutlass_conv_op` by pointer and interpreted as ElementCompute*, + // a different dtype here results in undefined epilogue behaviors float alpha, beta, gamma, delta, theta; + std::tie(alpha, beta, gamma, delta, theta) = get_constants(args); + float dst_scale = 0.f; + float threshold = 0.f; + uint8_t src_zero = 0; + bool load_from_const = !(fh == 1 && fw == 1); + bool without_shared_load = true; + + if (args.dst_layout->dtype.enumv() == DTypeEnum::Quantized4Asymm) { + dst_scale = + args.dst_layout->dtype.param().scale; + src_zero = args.src_layout->dtype.param() + .zero_point; + } else { // DTypeEnum::QuantizedS4 + dst_scale = args.dst_layout->dtype.param().scale; + } - ConvParam kern_param; - kern_param.n = n, kern_param.co = co, kern_param.ci = ci, - kern_param.hi = hi, kern_param.wi = wi, kern_param.ho = ho, - kern_param.wo = wo, kern_param.ph = ph, kern_param.pw = pw, - kern_param.sh = sh, kern_param.sw = sw, kern_param.fh = fh, - kern_param.fw = fw; + cudaStream_t stream = cuda_stream(args.opr->handle()); - uint32_t nonlinear_mode = static_cast(param.nonlineMode); + const auto* op = get_cutlass_conv_op(args, ConvOperator::kFprop, + ConvType::kConvolution, + load_from_const, without_shared_load); - cudaStream_t stream = cuda_stream(args.opr->handle()); + execute_cutlass_conv_op(op, args.src_tensor->raw_ptr, filter_ptr, bias_ptr, + z_ptr, args.dst_tensor->raw_ptr, nullptr, n, hi, wi, + ci, co, fh, fw, ho, wo, ph, pw, sh, sw, dh, dw, + &alpha, &beta, &gamma, &delta, &theta, &threshold, + &dst_scale, stream, &src_zero); - do_exec(args, filter_ptr, bias_ptr, z_ptr, kern_param, nonlinear_mode, - alpha, beta, gamma, delta, theta, stream); + after_kernel_launch(); } std::string ConvBiasForwardImpl::AlgoInt4NCHW64IMMAImplicitGemmBase::to_string( diff --git a/dnn/src/cuda/conv_bias/implicit_gemm_int4_nhwc_imma_base.cpp b/dnn/src/cuda/conv_bias/implicit_gemm_int4_nhwc_imma_base.cpp index 9cb11de6..353e26d9 100644 --- a/dnn/src/cuda/conv_bias/implicit_gemm_int4_nhwc_imma_base.cpp +++ b/dnn/src/cuda/conv_bias/implicit_gemm_int4_nhwc_imma_base.cpp @@ -10,10 +10,9 @@ * implied. 
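To make the contract in the `\note` comment above concrete: the epilogue scalars cross a type-erased boundary, so their static type at the call site must match the operator's ElementCompute, which is float for every kernel generated by this patch. Below is a minimal, self-contained sketch of that hand-off; the struct and function names are illustrative only and are not part of the patch:

#include <cassert>

// Stand-in for the type-erased scalar hand-off done by the library layer:
// scalars travel as void const* and are re-bound to ElementCompute (float
// here) inside the operation. Passing the address of anything but a float
// would reinterpret the wrong bytes, hence the float locals above.
struct ErasedEpilogueDemo {
    void const* alpha;
    void const* beta;
};

static float rebind_scalar(void const* p) {
    return *static_cast<float const*>(p);
}

int main() {
    float alpha = 2.0f, beta = 0.5f;
    ErasedEpilogueDemo demo{&alpha, &beta};
    assert(rebind_scalar(demo.alpha) == 2.0f);
    assert(rebind_scalar(demo.beta) == 0.5f);
    return 0;
}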
*/ -#include "./algo.h" #include "src/common/conv_bias.h" +#include "src/cuda/conv_bias/algo.h" #include "src/cuda/conv_bias/cutlass_reorder_filter.cuh" -#include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh" #include "src/cuda/conv_bias/reduce_filter.cuh" #include "src/cuda/convolution_helper/parameter.cuh" #include "src/cuda/utils.h" @@ -109,22 +108,43 @@ void ConvBiasForwardImpl::AlgoInt4NHWCIMMAImplicitGemmBase::exec( if (args.z_layout->ndim > 0) z_ptr = args.z_tensor->raw_ptr; + // \note these constants of cutlass epilogue will be passed to method + // `execute_cutlass_conv_op` by pointer and interpreted as ElementCompute*, + // a different dtype here results in undefined epilogue behaviors float alpha, beta, gamma, delta, theta; + std::tie(alpha, beta, gamma, delta, theta) = get_constants(args); + float dst_scale = 0.f; + float threshold = 0.f; + uint8_t src_zero = 0; + bool load_from_const = !(fh == 1 && fw == 1); + + bool without_shared_load = ((co % m_algo_param.threadblock_n == 0) && + (m_algo_param.threadblock_n == 32 || + m_algo_param.threadblock_n == 64)); + + if (args.dst_layout->dtype.enumv() == DTypeEnum::Quantized4Asymm) { + dst_scale = + args.dst_layout->dtype.param().scale; + src_zero = args.src_layout->dtype.param() + .zero_point; + } else { // DTypeEnum::QuantizedS4 + dst_scale = args.dst_layout->dtype.param().scale; + } - ConvParam kern_param; - kern_param.n = n, kern_param.co = co, kern_param.ci = ci, - kern_param.hi = hi, kern_param.wi = wi, kern_param.ho = ho, - kern_param.wo = wo, kern_param.ph = ph, kern_param.pw = pw, - kern_param.sh = sh, kern_param.sw = sw, kern_param.fh = fh, - kern_param.fw = fw; + cudaStream_t stream = cuda_stream(args.opr->handle()); - uint32_t nonlinear_mode = static_cast(param.nonlineMode); + const auto* op = get_cutlass_conv_op(args, ConvOperator::kFprop, + ConvType::kConvolution, + load_from_const, without_shared_load); - cudaStream_t stream = cuda_stream(args.opr->handle()); + execute_cutlass_conv_op(op, args.src_tensor->raw_ptr, filter_ptr, bias_ptr, + z_ptr, args.dst_tensor->raw_ptr, nullptr, n, hi, wi, + ci, co, fh, fw, ho, wo, ph, pw, sh, sw, dh, dw, + &alpha, &beta, &gamma, &delta, &theta, &threshold, + &dst_scale, stream, &src_zero); - do_exec(args, filter_ptr, bias_ptr, z_ptr, kern_param, nonlinear_mode, - alpha, beta, gamma, delta, theta, stream); + after_kernel_launch(); } std::string ConvBiasForwardImpl::AlgoInt4NHWCIMMAImplicitGemmBase::to_string( diff --git a/dnn/src/cuda/conv_bias/implicit_gemm_int8_nchw32_imma.cpp b/dnn/src/cuda/conv_bias/implicit_gemm_int8_nchw32_imma.cpp index f6654580..4eda5cd7 100644 --- a/dnn/src/cuda/conv_bias/implicit_gemm_int8_nchw32_imma.cpp +++ b/dnn/src/cuda/conv_bias/implicit_gemm_int8_nchw32_imma.cpp @@ -10,12 +10,11 @@ * implied. 
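The without_shared_load choice above is easiest to read as a standalone predicate. An illustrative sketch follows (the helper name is hypothetical; the patch computes this inline in exec()):

// The flag can be set only when the output channels tile exactly into a
// 32- or 64-wide threadblock column, which is what allows selecting the
// generated kernels whose epilogue skips the shared-memory round trip.
static bool can_use_without_shared_load(size_t co, uint32_t threadblock_n) {
    return co % threadblock_n == 0 &&
           (threadblock_n == 32 || threadblock_n == 64);
}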
*/ -#include "./algo.h" +#include "src/common/conv_bias.h" +#include "src/cuda/conv_bias/algo.h" #include "src/cuda/conv_bias/cutlass_reorder_filter.cuh" -#include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh" #include "src/cuda/convolution_helper/parameter.cuh" #include "src/cuda/utils.h" -#include "src/common/conv_bias.h" using namespace megdnn; using namespace cuda; @@ -38,8 +37,7 @@ bool ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::is_available( bool available = true; auto&& param = args.opr->param(); auto&& fm = args.filter_meta; - if (!check_bias_share_in_channel(*(args.bias_layout), - param.format)) + if (!check_bias_share_in_channel(*(args.bias_layout), param.format)) return false; if (param.format != Format::NCHW32 && param.format != Format::NCHW32_NCHW4) return false; @@ -137,19 +135,16 @@ void ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::exec( args.preprocessed_filter->tensors[0].raw_ptr); } - ConvParam kern_param; - kern_param.n = n, kern_param.co = co, kern_param.ci = ci, - kern_param.hi = hi, kern_param.wi = wi, kern_param.ho = ho, - kern_param.wo = wo, kern_param.ph = ph, kern_param.pw = pw, - kern_param.sh = sh, kern_param.sw = sw, kern_param.fh = fh, - kern_param.fw = fw; - float src_scale = args.src_layout->dtype.param().scale, filter_scale = args.filter_layout->dtype.param().scale, bias_scale = args.bias_layout->dtype.param().scale, dst_scale = args.dst_layout->dtype.param().scale; + + // \note these constants of cutlass epilogue will be passed to method + // `execute_cutlass_conv_op` by pointer and interpreted as ElementCompute*, + // a different dtype here results in undefined epilogue behaviors float alpha = src_scale * filter_scale / dst_scale, beta = bias_scale / dst_scale; int8_t* z_dev_ptr = nullptr; @@ -159,80 +154,20 @@ void ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::exec( float z_scale = args.z_layout->dtype.param().scale; gamma = z_scale / dst_scale; } - uint32_t nonlinear_mode = static_cast(param.nonlineMode); - if (fh == 1 && fw == 1) { - if (param.format == Format::NCHW32) { - cutlass_wrapper::do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32< - false>( - args.src_tensor->compatible_ptr(), filter_ptr, - args.bias_tensor->compatible_ptr(), z_dev_ptr, - args.dst_tensor->compatible_ptr(), nullptr, - kern_param, nonlinear_mode, alpha, beta, gamma, dst_scale, - cutlass_wrapper::GemmCoord{m_algo_param.threadblock_m, - m_algo_param.threadblock_n, - m_algo_param.threadblock_k}, - cutlass_wrapper::GemmCoord{m_algo_param.warp_m, - m_algo_param.warp_n, - m_algo_param.warp_k}, - m_algo_param.stage, stream); - } else { - megdnn_assert(param.format == Format::NCHW32_NCHW4); - cutlass_wrapper:: - do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4< - false>( - args.src_tensor->compatible_ptr(), - filter_ptr, - args.bias_tensor->compatible_ptr(), - z_dev_ptr, - args.dst_tensor->compatible_ptr(), nullptr, - kern_param, nonlinear_mode, alpha, beta, gamma, - dst_scale, - cutlass_wrapper::GemmCoord{ - m_algo_param.threadblock_m, - m_algo_param.threadblock_n, - m_algo_param.threadblock_k}, - cutlass_wrapper::GemmCoord{m_algo_param.warp_m, - m_algo_param.warp_n, - m_algo_param.warp_k}, - m_algo_param.stage, stream); - } - } else { - if (param.format == Format::NCHW32) { - cutlass_wrapper::do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32< - true>( - args.src_tensor->compatible_ptr(), filter_ptr, - args.bias_tensor->compatible_ptr(), z_dev_ptr, - args.dst_tensor->compatible_ptr(), nullptr, - kern_param, nonlinear_mode, alpha, beta, gamma, 
dst_scale, - cutlass_wrapper::GemmCoord{m_algo_param.threadblock_m, - m_algo_param.threadblock_n, - m_algo_param.threadblock_k}, - cutlass_wrapper::GemmCoord{m_algo_param.warp_m, - m_algo_param.warp_n, - m_algo_param.warp_k}, - m_algo_param.stage, stream); - } else { - megdnn_assert(param.format == Format::NCHW32_NCHW4); - cutlass_wrapper:: - do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4< - true>( - args.src_tensor->compatible_ptr(), - filter_ptr, - args.bias_tensor->compatible_ptr(), - z_dev_ptr, - args.dst_tensor->compatible_ptr(), nullptr, - kern_param, nonlinear_mode, alpha, beta, gamma, - dst_scale, - cutlass_wrapper::GemmCoord{ - m_algo_param.threadblock_m, - m_algo_param.threadblock_n, - m_algo_param.threadblock_k}, - cutlass_wrapper::GemmCoord{m_algo_param.warp_m, - m_algo_param.warp_n, - m_algo_param.warp_k}, - m_algo_param.stage, stream); - } - } + float delta = 0.f, theta = 0.f, threshold = 0.f; + bool load_from_const = !(fh == 1 && fw == 1); + bool without_shared_load = (param.format == Format::NCHW32); + + const auto* op = get_cutlass_conv_op(args, ConvOperator::kFprop, + ConvType::kConvolution, + load_from_const, without_shared_load); + + execute_cutlass_conv_op( + op, args.src_tensor->raw_ptr, filter_ptr, args.bias_tensor->raw_ptr, + z_dev_ptr, args.dst_tensor->raw_ptr, nullptr, n, hi, wi, ci, co, fh, + fw, ho, wo, ph, pw, sh, sw, dh, dw, &alpha, &beta, &gamma, &delta, + &theta, &threshold, &dst_scale, stream); + after_kernel_launch(); } @@ -249,9 +184,8 @@ size_t ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm:: return 0_z; } -SmallVector ConvBiasForwardImpl:: - AlgoInt8NCHW32IMMAImplicitGemm::deduce_preprocessed_filter_layout( - const SizeArgs& args) const { +SmallVector ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm:: + deduce_preprocessed_filter_layout(const SizeArgs& args) const { return {args.filter_layout->collapse_contiguous()}; } diff --git a/dnn/src/cuda/conv_bias/implicit_gemm_int8_nchw4_dp4a.cpp b/dnn/src/cuda/conv_bias/implicit_gemm_int8_nchw4_dp4a.cpp index ea40e918..348c122d 100644 --- a/dnn/src/cuda/conv_bias/implicit_gemm_int8_nchw4_dp4a.cpp +++ b/dnn/src/cuda/conv_bias/implicit_gemm_int8_nchw4_dp4a.cpp @@ -6,14 +6,14 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
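A worked instance of the scale algebra used by these exec() implementations (the numbers are hypothetical, chosen only to make the arithmetic visible): with src_scale = 0.5, filter_scale = 0.2, bias_scale = 0.05 and dst_scale = 0.1, we get alpha = 0.5 * 0.2 / 0.1 = 1.0 and beta = 0.05 / 0.1 = 0.5, so the int32 accumulator and the int32 bias are both rescaled into the destination's quantized domain before the nonlinearity and the final clamp are applied.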
*/ -#include "./algo.h" -#include "src/cuda/utils.h" -#include "src/cuda/convolution_helper/parameter.cuh" -#include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh" #include "src/common/conv_bias.h" +#include "src/cuda/conv_bias/algo.h" +#include "src/cuda/convolution_helper/parameter.cuh" +#include "src/cuda/utils.h" using namespace megdnn; using namespace cuda; @@ -34,8 +34,7 @@ bool ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::is_available( bool available = true; auto&& param = args.opr->param(); auto&& fm = args.filter_meta; - if (!check_bias_share_in_channel(*(args.bias_layout), - param.format)) + if (!check_bias_share_in_channel(*(args.bias_layout), param.format)) return false; bool valid_format = param.format == Format::NCHW4_NCHW32 && m_algo_param.threadblock_m % 32 == 0; @@ -48,7 +47,8 @@ bool ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::is_available( (args.dst_layout->dtype.enumv() == DTypeEnum::QuantizedS4 || args.dst_layout->dtype.enumv() == DTypeEnum::Quantized4Asymm); valid_format |= param.format == Format::NCHW4; - if (!valid_format) return false; + if (!valid_format) + return false; size_t n = args.src_layout->operator[](0), ci = args.src_layout->operator[](1) * 4, hi = args.src_layout->operator[](2), @@ -170,16 +170,13 @@ void ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::exec( args.preprocessed_filter->tensors[0].raw_ptr); } - convolution::ConvParam kern_param; - kern_param.n = n, kern_param.co = co, kern_param.ci = ci, - kern_param.hi = hi, kern_param.wi = wi, kern_param.ho = ho, - kern_param.wo = wo, kern_param.ph = ph, kern_param.pw = pw, - kern_param.sh = sh, kern_param.sw = sw, kern_param.fh = fh, - kern_param.fw = fw; - float src_scale = args.src_layout->dtype.param().scale, filter_scale = args.filter_layout->dtype.param().scale; + + // \note these constants of cutlass epilogue will be passed to method + // `execute_cutlass_conv_op` by pointer and interpreted as ElementCompute*, + // a different dtype here results in undefined epilogue behaviors float alpha = src_scale * filter_scale; float beta = 1.f; float dst_scale = 1.f; @@ -192,13 +189,15 @@ void ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::exec( if (args.bias_layout->dtype.enumv() == DTypeEnum::QuantizedS32) { megdnn_assert(args.dst_layout->dtype.category() == DTypeCategory::QUANTIZED); - float bias_scale = args.bias_layout->dtype.param() - .scale; + float bias_scale = + args.bias_layout->dtype.param().scale; dst_scale = get_scale(args.dst_layout->dtype); alpha /= dst_scale, beta = bias_scale / dst_scale; } float delta = 0.f; + void* z_ptr = nullptr; if (args.z_layout->ndim > 0) { + z_ptr = args.z_tensor->raw_ptr; gamma = 1.f; if (args.z_layout->dtype.category() == DTypeCategory::QUANTIZED) { megdnn_assert(args.dst_layout->dtype.category() == @@ -213,98 +212,20 @@ void ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::exec( delta = -z_zero * gamma; } } - uint32_t nonlinear_mode = static_cast(param.nonlineMode); - bool nonunity_kernel = !(fh == 1 && fw == 1); -#define DISPATCH(_nonunity_kernel) \ - if (nonunity_kernel == _nonunity_kernel) { \ - cb(_nonunity_kernel) \ - } - if (param.format == Format::NCHW4) { -#define cb(_nonunity_kernel) \ - cutlass_wrapper::do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4< \ - _nonunity_kernel>( \ - args.src_tensor->compatible_ptr(), filter_ptr, \ - args.bias_tensor->compatible_ptr(), \ - args.z_tensor->compatible_ptr(), \ - args.dst_tensor->compatible_ptr(), nullptr, kern_param, \ - nonlinear_mode, alpha, beta, gamma, dst_scale, 
\ - cutlass_wrapper::GemmCoord{m_algo_param.threadblock_m, \ - m_algo_param.threadblock_n, \ - m_algo_param.threadblock_k}, \ - cutlass_wrapper::GemmCoord{m_algo_param.warp_m, \ - m_algo_param.warp_n, \ - m_algo_param.warp_k}, \ - m_algo_param.stage, stream); - DISPATCH(true); - DISPATCH(false); -#undef cb - } else if (param.format == Format::NCHW4_NCHW) { -#define cb(_nonunity_kernel) \ - cutlass_wrapper::do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nchw< \ - _nonunity_kernel>( \ - args.src_tensor->compatible_ptr(), filter_ptr, \ - args.bias_tensor->compatible_ptr(), \ - args.z_tensor->compatible_ptr(), \ - args.dst_tensor->compatible_ptr(), nullptr, kern_param, \ - nonlinear_mode, alpha, beta, gamma, dst_scale, \ - cutlass_wrapper::GemmCoord{m_algo_param.threadblock_m, \ - m_algo_param.threadblock_n, \ - m_algo_param.threadblock_k}, \ - cutlass_wrapper::GemmCoord{m_algo_param.warp_m, \ - m_algo_param.warp_n, \ - m_algo_param.warp_k}, \ - m_algo_param.stage, stream); - DISPATCH(true); - DISPATCH(false); -#undef cb - } else if (param.format == Format::NCHW4_NHWC) { -#define cb(_signedness) \ - cutlass_wrapper::do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nhwc< \ - _signedness>( \ - args.src_tensor->compatible_ptr(), filter_ptr, \ - args.bias_tensor->compatible_ptr(), \ - reinterpret_cast(args.z_tensor->raw_ptr), \ - reinterpret_cast(args.dst_tensor->raw_ptr), nullptr, \ - kern_param, nonlinear_mode, alpha, beta, gamma, delta, theta, \ - dst_scale, \ - cutlass_wrapper::GemmCoord{m_algo_param.threadblock_m, \ - m_algo_param.threadblock_n, \ - m_algo_param.threadblock_k}, \ - cutlass_wrapper::GemmCoord{m_algo_param.warp_m, \ - m_algo_param.warp_n, \ - m_algo_param.warp_k}, \ - m_algo_param.stage, stream); - if (args.dst_layout->dtype.enumv() == DTypeEnum::QuantizedS4) { - cb(true); - } else { - megdnn_assert(args.dst_layout->dtype.enumv() == - DTypeEnum::Quantized4Asymm); - cb(false); - } -#undef cb - } else { - megdnn_assert(param.format == Format::NCHW4_NCHW32); -#define cb(_nonunity_kernel) \ - cutlass_wrapper:: \ - do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32< \ - _nonunity_kernel>( \ - args.src_tensor->compatible_ptr(), filter_ptr, \ - args.bias_tensor->compatible_ptr(), \ - args.z_tensor->compatible_ptr(), \ - args.dst_tensor->compatible_ptr(), nullptr, \ - kern_param, nonlinear_mode, alpha, beta, gamma, dst_scale, \ - cutlass_wrapper::GemmCoord{m_algo_param.threadblock_m, \ - m_algo_param.threadblock_n, \ - m_algo_param.threadblock_k}, \ - cutlass_wrapper::GemmCoord{m_algo_param.warp_m, \ - m_algo_param.warp_n, \ - m_algo_param.warp_k}, \ - m_algo_param.stage, stream); - DISPATCH(true); - DISPATCH(false); -#undef cb -#undef DISPATCH - } + float threshold = 0.f; + bool load_from_const = !(fh == 1 && fw == 1); + bool without_shared_load = false; + + const auto* op = get_cutlass_conv_op(args, ConvOperator::kFprop, + ConvType::kConvolution, + load_from_const, without_shared_load); + + execute_cutlass_conv_op( + op, args.src_tensor->raw_ptr, filter_ptr, args.bias_tensor->raw_ptr, + z_ptr, args.dst_tensor->raw_ptr, nullptr, n, hi, wi, ci, co, fh, fw, + ho, wo, ph, pw, sh, sw, dh, dw, &alpha, &beta, &gamma, &delta, + &theta, &threshold, &dst_scale, stream); + after_kernel_launch(); } diff --git a/dnn/src/cuda/conv_bias/implicit_gemm_uint4_int4_nchw64_imma.cpp b/dnn/src/cuda/conv_bias/implicit_gemm_uint4_int4_nchw64_imma.cpp index d94833ee..1be3ea4e 100644 --- a/dnn/src/cuda/conv_bias/implicit_gemm_uint4_int4_nchw64_imma.cpp +++ 
b/dnn/src/cuda/conv_bias/implicit_gemm_uint4_int4_nchw64_imma.cpp @@ -10,8 +10,7 @@ * implied. */ -#include "./algo.h" -#include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh" +#include "src/cuda/conv_bias/algo.h" #include "src/cuda/conv_bias/reduce_filter.cuh" #include "src/cuda/utils.h" @@ -120,32 +119,15 @@ ConvBiasForwardImpl::AlgoUInt4Int4NCHW64IMMAImplicitGemm::get_constants( delta = -z_zero * gamma; } - return {alpha, beta, gamma, delta, theta}; -} + // identity epilogue has no theta: + // alpha * accumulator + beta * bias + gamma * source + delta + if (args.opr->param().nonlineMode == + param::ConvBias::NonlineMode::IDENTITY) { + delta += theta; + theta = 0.f; + } -void ConvBiasForwardImpl::AlgoUInt4Int4NCHW64IMMAImplicitGemm::do_exec( - const ExecArgs& args, void* filter_ptr, void* bias_ptr, void* z_ptr, - ConvParam kern_param, uint32_t nonlinear_mode, float alpha, float beta, - float gamma, float delta, float theta, cudaStream_t stream) const { - float dst_scale = - args.dst_layout->dtype.param().scale; - uint8_t src_zero = - args.src_layout->dtype.param().zero_point; - cutlass_wrapper::GemmCoord threadblock_shape{m_algo_param.threadblock_m, - m_algo_param.threadblock_n, - m_algo_param.threadblock_k}; - - cutlass_wrapper::GemmCoord warp_shape{ - m_algo_param.warp_m, m_algo_param.warp_n, m_algo_param.warp_k}; - cutlass_wrapper::do_conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64< - true>(reinterpret_cast(args.src_tensor->raw_ptr), - reinterpret_cast(filter_ptr), - reinterpret_cast(bias_ptr), - reinterpret_cast(z_ptr), - reinterpret_cast(args.dst_tensor->raw_ptr), nullptr, - kern_param, nonlinear_mode, alpha, beta, gamma, delta, theta, - dst_scale, src_zero, threadblock_shape, warp_shape, - m_algo_param.stage, stream); + return {alpha, beta, gamma, delta, theta}; } void ConvBiasForwardImpl::AlgoUInt4Int4NCHW64IMMAImplicitGemm::update_bias( diff --git a/dnn/src/cuda/conv_bias/implicit_gemm_uint4_int4_nhwc_imma.cpp b/dnn/src/cuda/conv_bias/implicit_gemm_uint4_int4_nhwc_imma.cpp index b074a48d..d271449e 100644 --- a/dnn/src/cuda/conv_bias/implicit_gemm_uint4_int4_nhwc_imma.cpp +++ b/dnn/src/cuda/conv_bias/implicit_gemm_uint4_int4_nhwc_imma.cpp @@ -10,8 +10,7 @@ * implied. 
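The folding just above is exact rather than approximate: the identity epilogue evaluates alpha * acc + beta * bias + gamma * z + delta and has no separate theta input, and the expression is affine in its constant term, so the theta that the other epilogues apply separately can simply be absorbed into delta. A sketch for clarity (illustrative, not code from this patch):

// The form computed by the identity epilogue.
static float identity_epilogue(float acc, float bias, float z, float alpha,
                               float beta, float gamma, float delta) {
    return alpha * acc + beta * bias + gamma * z + delta;
}
// identity_epilogue(acc, bias, z, alpha, beta, gamma, delta + theta)
// yields alpha * acc + beta * bias + gamma * z + delta + theta, which is
// exactly the intended result with theta folded away.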
*/ -#include "./algo.h" -#include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh" +#include "src/cuda/conv_bias/algo.h" #include "src/cuda/conv_bias/reduce_filter.cuh" #include "src/cuda/utils.h" @@ -121,44 +120,15 @@ ConvBiasForwardImpl::AlgoUInt4Int4NHWCIMMAImplicitGemm::get_constants( delta = -z_zero * gamma; } - return {alpha, beta, gamma, delta, theta}; -} - -void ConvBiasForwardImpl::AlgoUInt4Int4NHWCIMMAImplicitGemm::do_exec( - const ExecArgs& args, void* filter_ptr, void* bias_ptr, void* z_ptr, - ConvParam kern_param, uint32_t nonlinear_mode, float alpha, float beta, - float gamma, float delta, float theta, cudaStream_t stream) const { - float dst_scale = - args.dst_layout->dtype.param().scale; - uint8_t src_zero = - args.src_layout->dtype.param().zero_point; - cutlass_wrapper::GemmCoord threadblock_shape{m_algo_param.threadblock_m, - m_algo_param.threadblock_n, - m_algo_param.threadblock_k}; - - cutlass_wrapper::GemmCoord warp_shape{ - m_algo_param.warp_m, m_algo_param.warp_n, m_algo_param.warp_k}; - if (kern_param.fh == 1 && kern_param.fw == 1) { - cutlass_wrapper::do_conv_bias_uint4_int4_implicit_gemm_imma_nhwc( - reinterpret_cast(args.src_tensor->raw_ptr), - reinterpret_cast(filter_ptr), - reinterpret_cast(bias_ptr), - reinterpret_cast(z_ptr), - reinterpret_cast(args.dst_tensor->raw_ptr), nullptr, - kern_param, nonlinear_mode, alpha, beta, gamma, delta, theta, - dst_scale, src_zero, threadblock_shape, warp_shape, - m_algo_param.access_size, m_algo_param.stage, stream); - } else { - cutlass_wrapper::do_conv_bias_uint4_int4_implicit_gemm_imma_nhwc( - reinterpret_cast(args.src_tensor->raw_ptr), - reinterpret_cast(filter_ptr), - reinterpret_cast(bias_ptr), - reinterpret_cast(z_ptr), - reinterpret_cast(args.dst_tensor->raw_ptr), nullptr, - kern_param, nonlinear_mode, alpha, beta, gamma, delta, theta, - dst_scale, src_zero, threadblock_shape, warp_shape, - m_algo_param.access_size, m_algo_param.stage, stream); + // identity epilogue has no theta: + // alpha * accumulator + beta * bias + gamma * source + delta + if (args.opr->param().nonlineMode == + param::ConvBias::NonlineMode::IDENTITY) { + delta += theta; + theta = 0.f; } + + return {alpha, beta, gamma, delta, theta}; } void ConvBiasForwardImpl::AlgoUInt4Int4NHWCIMMAImplicitGemm::update_bias( diff --git a/dnn/src/cuda/conv_bias/opr_impl.h b/dnn/src/cuda/conv_bias/opr_impl.h index 43e612d2..4af7896b 100644 --- a/dnn/src/cuda/conv_bias/opr_impl.h +++ b/dnn/src/cuda/conv_bias/opr_impl.h @@ -57,6 +57,7 @@ public: class AlgoBatchedMatmul; class AlgoGroupConvGeneral; class AlgoQUInt4x4x32WMMA; + class AlgoCutlassConvolutionBase; class AlgoInt8CHWN4DotProdImplicitGemm; class AlgoInt8NCHW4DotProdImplicitGemm; class AlgoInt8CHWN4IMMAImplicitGemm; diff --git a/dnn/src/cuda/convolution/backward_data/cutlass_deconvolution_wrapper.cu b/dnn/src/cuda/convolution/backward_data/cutlass_deconvolution_wrapper.cu deleted file mode 100644 index 63ca450e..00000000 --- a/dnn/src/cuda/convolution/backward_data/cutlass_deconvolution_wrapper.cu +++ /dev/null @@ -1,100 +0,0 @@ -/** - * \file src/cuda/convolution/backward_data/cutlass_deconvolution_wrapper.cu - * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") - * - * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or - * implied. 
- */ -// ignore warning of cutlass -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wunused-parameter" -#pragma GCC diagnostic ignored "-Wstrict-aliasing" - -#if !MEGDNN_TEGRA_X1 -#include "cutlass/convolution/device/convolution.h" -#endif -#include "src/common/opr_param_defs_enumv.cuh" -#include "src/cuda/convolution/backward_data/cutlass_deconvolution_wrapper.cuh" -#pragma GCC diagnostic pop - -using namespace megdnn; -using namespace cuda; -using namespace cutlass_wrapper; - -/* ================ cutlass kernel wrapper for nchw4 layout ================= */ -#if MEGDNN_TEGRA_X1 -void megdnn::cuda::cutlass_wrapper::do_deconv_int8_implicit_gemm_dp4a_ncdiv4hw4( - const int8_t* /* d_src */, const int8_t* /* d_filter */, - int8_t* /* d_dst */, int* /* workspace */, - const convolution::ConvParam& /* param */, float /* alpha */, - const GemmCoord& /* threadblock_shape */, - const GemmCoord& /* warp_shape */, int /* stages */, - cudaStream_t /* stream */) {} -#else -void megdnn::cuda::cutlass_wrapper::do_deconv_int8_implicit_gemm_dp4a_ncdiv4hw4( - const int8_t* d_src, const int8_t* d_filter, int8_t* d_dst, - int* workspace, const convolution::ConvParam& param, float alpha, - const GemmCoord& threadblock_shape, const GemmCoord& warp_shape, - int stages, cudaStream_t stream) { -#define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \ - threadblock_k_, warp_m_, warp_n_, \ - warp_k_, stage_, aligned_) \ - if (threadblock_shape.m() == threadblock_m_ && \ - threadblock_shape.n() == threadblock_n_ && \ - threadblock_shape.k() == threadblock_k_ && \ - warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ - warp_shape.k() == warp_k_ && stages == stage_) { \ - using ThreadBlockShape = \ - cutlass::gemm::GemmShape; \ - using WarpShape = cutlass::gemm::GemmShape; \ - using InstructionShape = cutlass::gemm::GemmShape<1, 1, 4>; \ - using Deconvolution = cutlass::conv::device::Deconvolution< \ - int8_t, cutlass::layout::TensorNCxHWx<4>, int8_t, \ - cutlass::layout::TensorKxRSCx<4>, ElementOutput, \ - cutlass::layout::TensorNCxHWx<4>, int32_t, \ - cutlass::layout::TensorNCxHWx<4>, int32_t, \ - cutlass::arch::OpClassSimt, cutlass::arch::Sm61, \ - ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ - cutlass::conv::threadblock:: \ - ConvolutionDgradNCxHWxThreadblockSwizzle, \ - stage_, 4, aligned_, true, cutlass::arch::OpMultiplyAdd>; \ - typename Deconvolution::ConvolutionParameter conv_param( \ - param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ - param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ - param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \ - return cutlass_deconvolution_wrapper( \ - d_src, d_filter, nullptr, nullptr, d_dst, workspace, \ - conv_param, epilogue, stream); \ - } -#define DISPATCH_KERNEL \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 64, 8, 16, 64, 8, 2, 4); \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 128, 16, 16, 64, 16, 2, 4); \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 128, 16, 16, 128, 16, 1, 8); \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 2, 16); \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 64, 32, 32, 2, 16); \ - megdnn_assert(false, \ - "unsupported threadblock shape (%dx%dx%d) and warp shape " \ - "(%dx%dx%d)", \ - threadblock_shape.m(), threadblock_shape.n(), \ - threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ - warp_shape.k()); - using ElementOutput = int8_t; - using ElementAccumulator = int32_t; - using ElementBias = int32_t; - using ElementCompute = float; - using EpilogueOp = 
cutlass::epilogue::thread::BiasAddLinearCombinationClamp< - ElementOutput, 4, ElementAccumulator, ElementBias, ElementCompute>; - typename EpilogueOp::Params epilogue{alpha, 0, 0}; - DISPATCH_KERNEL; - -#undef DISPATCH_KERNEL_WITH_TILE_SHAPE -#undef DISPATCH_KERNEL -} -#endif - -// vim: syntax=cuda.doxygen diff --git a/dnn/src/cuda/convolution/backward_data/cutlass_deconvolution_wrapper.cuh b/dnn/src/cuda/convolution/backward_data/cutlass_deconvolution_wrapper.cuh deleted file mode 100644 index 35961673..00000000 --- a/dnn/src/cuda/convolution/backward_data/cutlass_deconvolution_wrapper.cuh +++ /dev/null @@ -1,44 +0,0 @@ -/** - * \file src/cuda/convolution/backward_data/cutlass_deconvolution_wrapper.cuh - * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") - * - * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or - * implied. - */ -#pragma once -#include "cutlass/gemm/gemm.h" -#include "src/cuda/convolution_helper/parameter.cuh" -#include "src/cuda/utils.cuh" - -namespace megdnn { -namespace cuda { -namespace cutlass_wrapper { - -using GemmCoord = cutlass::gemm::GemmCoord; - -template -void cutlass_deconvolution_wrapper( - const typename Convolution::ElementSrc* d_src, - const typename Convolution::ElementFilter* d_filter, - const typename Convolution::ElementBias* d_bias, - const typename Convolution::ElementDst* d_z, - typename Convolution::ElementDst* d_dst, int* workspace, - typename Convolution::ConvolutionParameter const& conv_param, - typename Convolution::EpilogueOutputOp::Params const& epilogue, - cudaStream_t stream); - -void do_deconv_int8_implicit_gemm_dp4a_ncdiv4hw4( - const int8_t* d_src, const int8_t* d_filter, int8_t* d_dst, - int* workspace, const convolution::ConvParam& param, float alpha, - const GemmCoord& threadblock_shape, const GemmCoord& warp_shape, - int stages, cudaStream_t stream); - -} // namespace cutlass_wrapper -} // namespace cuda -} // namespace megdnn - -// vim: syntax=cuda.doxygen diff --git a/dnn/src/cuda/convolution/backward_data/implicit_gemm_deconv_cutlass_wrapper.cuinl b/dnn/src/cuda/convolution/backward_data/implicit_gemm_deconv_cutlass_wrapper.cuinl deleted file mode 100644 index 349b392a..00000000 --- a/dnn/src/cuda/convolution/backward_data/implicit_gemm_deconv_cutlass_wrapper.cuinl +++ /dev/null @@ -1,62 +0,0 @@ -/** - * \file - * dnn/src/cuda/convolution/backward_data/implicit_gemm_deconv_cutlass_wrapper.cuinl - * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") - * - * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or - * implied. 
- */ -#include "cutlass/convolution/device/convolution.h" -#include "src/cuda/convolution/backward_data/cutlass_deconvolution_wrapper.cuh" - -using namespace megdnn; -using namespace cuda; -using namespace cutlass_wrapper; - -template -void megdnn::cuda::cutlass_wrapper::cutlass_deconvolution_wrapper( - const typename Deconvolution::ElementSrc* d_src, - const typename Deconvolution::ElementFilter* d_filter, - const typename Deconvolution::ElementBias* d_bias, - const typename Deconvolution::ElementDst* d_z, - typename Deconvolution::ElementDst* d_dst, int* workspace, - typename Deconvolution::ConvolutionParameter const& conv_param, - typename Deconvolution::EpilogueOutputOp::Params const& epilogue, - cudaStream_t stream) { - typename Deconvolution::TensorRefSrc tensor_src{ - const_cast(d_src), - Deconvolution::LayoutSrc::packed( - {conv_param.N, conv_param.P, conv_param.Q, conv_param.K})}; - typename Deconvolution::TensorRefFilter tensor_filter{ - const_cast(d_filter), - Deconvolution::LayoutFilter::packed( - {conv_param.K, conv_param.R, conv_param.S, conv_param.C})}; - typename Deconvolution::TensorRefBias tensor_bias{ - const_cast(d_bias), - Deconvolution::LayoutBias::packed({1, 1, 1, conv_param.K})}; - typename Deconvolution::TensorRefDst tensor_z{ - const_cast(d_z), - Deconvolution::LayoutDst::packed( - {conv_param.N, conv_param.H, conv_param.W, conv_param.C})}; - typename Deconvolution::TensorRefDst tensor_dst{ - d_dst, - Deconvolution::LayoutDst::packed( - {conv_param.N, conv_param.H, conv_param.W, conv_param.C})}; - typename Deconvolution::Arguments arguments{conv_param, - tensor_src.non_const_ref(), - tensor_filter.non_const_ref(), - tensor_bias.non_const_ref(), - tensor_z.non_const_ref(), - tensor_dst.non_const_ref(), - epilogue}; - Deconvolution deconv_op; - cutlass_check(deconv_op.initialize(arguments, workspace)); - cutlass_check(deconv_op(stream)); - after_kernel_launch(); -} - -// vim: syntax=cuda.doxygen diff --git a/dnn/src/cuda/convolution/backward_data/implicit_gemm_int8_nchw4_dp4a.cpp b/dnn/src/cuda/convolution/backward_data/implicit_gemm_int8_nchw4_dp4a.cpp index 15380a34..c4f631e6 100644 --- a/dnn/src/cuda/convolution/backward_data/implicit_gemm_int8_nchw4_dp4a.cpp +++ b/dnn/src/cuda/convolution/backward_data/implicit_gemm_int8_nchw4_dp4a.cpp @@ -1,5 +1,6 @@ /** - * \file dnn/src/cuda/convolution/backward_data/implicit_gemm_int8_nchw4_dp4a.cpp + * \file + * dnn/src/cuda/convolution/backward_data/implicit_gemm_int8_nchw4_dp4a.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -10,11 +11,11 @@ * implied. 
*/ -#include "./algo.h" -#include "src/cuda/utils.h" -#include "src/cuda/convolution_helper/parameter.cuh" -#include "src/cuda/convolution/backward_data/cutlass_deconvolution_wrapper.cuh" +#include "src/cuda/convolution/backward_data/algo.h" #include "src/cuda/convolution/backward_data/deconv_int8_helper.cuh" +#include "src/cuda/convolution_helper/parameter.cuh" +#include "src/cuda/cutlass/singleton.h" +#include "src/cuda/utils.h" using namespace megdnn; using namespace cuda; @@ -70,6 +71,7 @@ size_t ConvolutionBackwardDataImpl::AlgoInt8NCHW4DotProdImplicitGemm:: void ConvolutionBackwardDataImpl::AlgoInt8NCHW4DotProdImplicitGemm::exec( const ExecArgs& args) const { + auto&& param = args.opr->param(); auto&& fm = args.filter_meta; size_t n = args.diff_layout->operator[](0), co = args.diff_layout->operator[](1) * 4, @@ -81,6 +83,7 @@ void ConvolutionBackwardDataImpl::AlgoInt8NCHW4DotProdImplicitGemm::exec( size_t fh = fm.spatial[0], fw = fm.spatial[1]; size_t sh = fm.stride[0], sw = fm.stride[1]; size_t ph = fm.padding[0], pw = fm.padding[1]; + size_t dh = param.dilate_h, dw = param.dilate_w; auto&& stream = cuda_stream(args.opr->handle()); @@ -93,12 +96,6 @@ void ConvolutionBackwardDataImpl::AlgoInt8NCHW4DotProdImplicitGemm::exec( filter_ptr, args.filter_tensor->compatible_ptr(), co, ci, fh, fw, stream); } - convolution::ConvParam kern_param; - kern_param.n = n, kern_param.co = co, kern_param.ci = ci, - kern_param.hi = hi, kern_param.wi = wi, kern_param.ho = ho, - kern_param.wo = wo, kern_param.ph = ph, kern_param.pw = pw, - kern_param.sh = sh, kern_param.sw = sw, kern_param.fh = fh, - kern_param.fw = fw; float diff_scale = args.diff_layout->dtype.param().scale, @@ -106,17 +103,60 @@ void ConvolutionBackwardDataImpl::AlgoInt8NCHW4DotProdImplicitGemm::exec( args.filter_layout->dtype.param().scale, grad_scale = args.grad_layout->dtype.param().scale; - float alpha = diff_scale * filter_scale / grad_scale; - cutlass_wrapper::do_deconv_int8_implicit_gemm_dp4a_ncdiv4hw4( - args.diff_tensor->compatible_ptr(), filter_ptr, - args.grad_tensor->compatible_ptr(), nullptr, kern_param, - alpha, - cutlass_wrapper::GemmCoord{m_algo_param.threadblock_m, - m_algo_param.threadblock_n, - m_algo_param.threadblock_k}, - cutlass_wrapper::GemmCoord{m_algo_param.warp_m, m_algo_param.warp_n, - m_algo_param.warp_k}, - m_algo_param.stage, stream); + + // \note these constants of cutlass epilogue will be passed to struct + // `ConvolutionArguments` by pointer and interpreted as ElementCompute*, a + // different dtype here results in undefined epilogue behaviors + float alpha = diff_scale * filter_scale / grad_scale, beta = 0.f, + gamma = 0.f, delta = 0.f; + + using namespace cutlass::library; + + // only use 16x64x8_16x64x8_2stages impl + ConvolutionKey key{ + cutlass::conv::Operator::kDgrad, + NumericTypeID::kS8, + LayoutTypeID::kTensorNC4HW4, + NumericTypeID::kS8, + LayoutTypeID::kTensorK4RSC4, + NumericTypeID::kS8, + LayoutTypeID::kTensorNC4HW4, + NumericTypeID::kS32, + LayoutTypeID::kTensorNC4HW4, + cutlass::conv::ConvType::kConvolution, + m_algo_param.threadblock_m, + m_algo_param.threadblock_n, + m_algo_param.threadblock_k, + m_algo_param.warp_m, + m_algo_param.warp_n, + m_algo_param.warp_k, + 1, + 1, + 4, + cutlass::epilogue::EpilogueType::kBiasAddLinearCombinationClamp, + m_algo_param.stage, + true, + false}; + + const Operation* op = Singleton::get().operation_table.find_op(key); + + // gcc prints warnings when size_t values are implicitly narrowed to int + cutlass::conv::Conv2dProblemSize problem_size{ + 
int(n), int(hi), int(wi), int(ci), + int(co), int(fh), int(fw), int(ho), + int(wo), int(ph), int(pw), int(sh), + int(sw), int(dh), int(dw), cutlass::conv::Mode::kCrossCorrelation}; + + cutlass::library::ConvolutionArguments conv_args{ + problem_size, args.diff_tensor->compatible_ptr(), + filter_ptr, nullptr, + nullptr, args.grad_tensor->compatible_ptr(), + &alpha, &beta, + &gamma, &delta, + nullptr, nullptr, + nullptr, nullptr}; + + cutlass_check(op->run(&conv_args, nullptr, stream)); after_kernel_launch(); } diff --git a/dnn/src/cuda/convolution/backward_data/implicit_gemm_int8_nchw_dp4a.cpp b/dnn/src/cuda/convolution/backward_data/implicit_gemm_int8_nchw_dp4a.cpp index ccf466ab..f4518802 100644 --- a/dnn/src/cuda/convolution/backward_data/implicit_gemm_int8_nchw_dp4a.cpp +++ b/dnn/src/cuda/convolution/backward_data/implicit_gemm_int8_nchw_dp4a.cpp @@ -11,16 +11,16 @@ * implied. */ -#include "./algo.h" -#include "src/cuda/utils.h" +#include "src/cuda/convolution/backward_data/algo.h" #include "src/cuda/convolution_helper/parameter.cuh" -#include "src/cuda/convolution/backward_data/cutlass_deconvolution_wrapper.cuh" +#include "src/cuda/cutlass/singleton.h" +#include "src/cuda/utils.h" using namespace megdnn; using namespace cuda; -bool ConvolutionBackwardDataImpl::AlgoInt8NCHWDotProdImplicitGemm:: - is_available(const SizeArgs& args) const { +bool ConvolutionBackwardDataImpl::AlgoInt8NCHWDotProdImplicitGemm::is_available( + const SizeArgs& args) const { auto&& fm = args.filter_meta; if (fm.format != Param::Format::NCHW) return false; @@ -42,7 +42,8 @@ bool ConvolutionBackwardDataImpl::AlgoInt8NCHWDotProdImplicitGemm:: // TODO support group deconv int8 available &= (fm.group == 1); // ic and oc must be multiples of 4 - available &= ((fm.group * fm.icpg) % 4 == 0 && (fm.group * fm.ocpg) % 4 == 0); + available &= + ((fm.group * fm.icpg) % 4 == 0 && (fm.group * fm.ocpg) % 4 == 0); // mode must be cross correlation available &= !fm.should_flip; // mode must be 2D @@ -73,6 +74,7 @@ size_t ConvolutionBackwardDataImpl::AlgoInt8NCHWDotProdImplicitGemm:: void ConvolutionBackwardDataImpl::AlgoInt8NCHWDotProdImplicitGemm::exec( const ExecArgs& args) const { + auto&& param = args.opr->param(); auto&& fm = args.filter_meta; size_t n = args.diff_layout->operator[](0), co = args.diff_layout->operator[](1), @@ -84,6 +86,7 @@ void ConvolutionBackwardDataImpl::AlgoInt8NCHWDotProdImplicitGemm::exec( size_t fh = fm.spatial[0], fw = fm.spatial[1]; size_t sh = fm.stride[0], sw = fm.stride[1]; size_t ph = fm.padding[0], pw = fm.padding[1]; + size_t dh = param.dilate_h, dw = param.dilate_w; auto&& stream = cuda_stream(args.opr->handle()); @@ -120,26 +123,63 @@ void ConvolutionBackwardDataImpl::AlgoInt8NCHWDotProdImplicitGemm::exec( } int8_t* inner_grad_ptr = reinterpret_cast(bundle.get(2)); - convolution::ConvParam kern_param; - kern_param.n = n, kern_param.co = co, kern_param.ci = ci, - kern_param.hi = hi, kern_param.wi = wi, kern_param.ho = ho, - kern_param.wo = wo, kern_param.ph = ph, kern_param.pw = pw, - kern_param.sh = sh, kern_param.sw = sw, kern_param.fh = fh, - kern_param.fw = fw; - float diff_scale = args.diff_layout->dtype.param().scale, filter_scale = args.filter_layout->dtype.param().scale, grad_scale = args.grad_layout->dtype.param().scale; - float alpha = diff_scale * filter_scale / grad_scale; + + // \note these constants of cutlass epilogue will be passed to struct + // `ConvolutionArguments` by pointer and interpreted as ElementCompute*, a + // different dtype here results in undefined epilogue 
behaviors + float alpha = diff_scale * filter_scale / grad_scale, beta = 0.f, + gamma = 0.f, delta = 0.f; + + using namespace cutlass::library; // only use 16x64x8_16x64x8_2stages impl - cutlass_wrapper::do_deconv_int8_implicit_gemm_dp4a_ncdiv4hw4( - inner_diff_ptr, inner_filter_ptr, inner_grad_ptr, nullptr, - kern_param, alpha, cutlass_wrapper::GemmCoord{16, 64, 8}, - cutlass_wrapper::GemmCoord{16, 64, 8}, 2, stream); + ConvolutionKey key{ + cutlass::conv::Operator::kDgrad, + NumericTypeID::kS8, + LayoutTypeID::kTensorNC4HW4, + NumericTypeID::kS8, + LayoutTypeID::kTensorK4RSC4, + NumericTypeID::kS8, + LayoutTypeID::kTensorNC4HW4, + NumericTypeID::kS32, + LayoutTypeID::kTensorNC4HW4, + cutlass::conv::ConvType::kConvolution, + 16, + 64, + 8, + 16, + 64, + 8, + 1, + 1, + 4, + cutlass::epilogue::EpilogueType::kBiasAddLinearCombinationClamp, + 2, + true, + false}; + + const Operation* op = Singleton::get().operation_table.find_op(key); + + // gcc prints warnings when size_t values are implicitly narrowed to int + cutlass::conv::Conv2dProblemSize problem_size{ + int(n), int(hi), int(wi), int(ci), + int(co), int(fh), int(fw), int(ho), + int(wo), int(ph), int(pw), int(sh), + int(sw), int(dh), int(dw), cutlass::conv::Mode::kCrossCorrelation}; + + cutlass::library::ConvolutionArguments conv_args{ + problem_size, inner_diff_ptr, inner_filter_ptr, nullptr, + nullptr, inner_grad_ptr, &alpha, &beta, + &gamma, &delta, nullptr, nullptr, + nullptr, nullptr}; + + cutlass_check(op->run(&conv_args, nullptr, stream)); after_kernel_launch(); diff --git a/dnn/src/cuda/cutlass/arch_mappings.h b/dnn/src/cuda/cutlass/arch_mappings.h new file mode 100644 index 00000000..bac4c7e6 --- /dev/null +++ b/dnn/src/cuda/cutlass/arch_mappings.h @@ -0,0 +1,107 @@ +/*************************************************************************************************** + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + *modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + *notice, this list of conditions and the following disclaimer in the + *documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its + *contributors may be used to endorse or promote products derived from this + *software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + *DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, + *INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + *DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY + *OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TOR (INCLUDING + *NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, + *EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
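Taken together, the two backward-data hunks above show the migration pattern this whole patch applies. Condensed side by side (both fragments appear verbatim in the hunks above; this is only a restatement):

// Before: hand-written wrapper, tile shape dispatched by macro expansion.
cutlass_wrapper::do_deconv_int8_implicit_gemm_dp4a_ncdiv4hw4(
        inner_diff_ptr, inner_filter_ptr, inner_grad_ptr, nullptr,
        kern_param, alpha, cutlass_wrapper::GemmCoord{16, 64, 8},
        cutlass_wrapper::GemmCoord{16, 64, 8}, 2, stream);

// After: runtime lookup keyed on dtype/layout/tile/epilogue, then a
// type-erased run() on the returned operation.
const Operation* op = Singleton::get().operation_table.find_op(key);
cutlass_check(op->run(&conv_args, nullptr, stream));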
+ *
+ **************************************************************************************************/
+/**
+ * \file dnn/src/cuda/cutlass/arch_mappings.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+
+#pragma once
+
+#include "cutlass/arch/arch.h"
+#include "cutlass/arch/mma.h"
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass {
+namespace library {
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <typename ArchTag, typename OperatorClass>
+struct ArchMap;
+
+template <>
+struct ArchMap<arch::Sm50, arch::OpClassSimt> {
+    static int const kMin = 50;
+    static int const kMax = 1024;
+};
+
+template <>
+struct ArchMap<arch::Sm60, arch::OpClassSimt> {
+    static int const kMin = 60;
+    static int const kMax = 1024;
+};
+
+template <>
+struct ArchMap<arch::Sm61, arch::OpClassSimt> {
+    static int const kMin = 61;
+    static int const kMax = 1024;
+};
+
+template <>
+struct ArchMap<arch::Sm70, arch::OpClassWmmaTensorOp> {
+    static int const kMin = 70;
+    static int const kMax = 1024;
+};
+
+template <>
+struct ArchMap<arch::Sm70, arch::OpClassTensorOp> {
+    static int const kMin = 70;
+    static int const kMax = 75;
+};
+
+template <typename OperatorClass>
+struct ArchMap<arch::Sm75, OperatorClass> {
+    static int const kMin = 75;
+    static int const kMax = 1024;
+};
+
+template <typename OperatorClass>
+struct ArchMap<arch::Sm80, OperatorClass> {
+    static int const kMin = 80;
+    static int const kMax = 1024;
+};
+
+template <typename OperatorClass>
+struct ArchMap<arch::Sm86, OperatorClass> {
+    static int const kMin = 86;
+    static int const kMax = 1024;
+};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace library
+} // namespace cutlass
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/dnn/src/cuda/cutlass/convolution_operation.h b/dnn/src/cuda/cutlass/convolution_operation.h
new file mode 100644
index 00000000..663bf8b9
--- /dev/null
+++ b/dnn/src/cuda/cutlass/convolution_operation.h
@@ -0,0 +1,307 @@
+/***************************************************************************************************
+ * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ *modification, are permitted provided that the following conditions are met:
+ *     * Redistributions of source code must retain the above copyright notice,
+ *this list of conditions and the following disclaimer.
+ *     * Redistributions in binary form must reproduce the above copyright
+ *notice, this list of conditions and the following disclaimer in the
+ *documentation and/or other materials provided with the distribution.
+ *     * Neither the name of the NVIDIA CORPORATION nor the names of its
+ *contributors may be used to endorse or promote products derived from this
+ *software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ *DISCLAIMED.
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, + *INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + *DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY + *OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TOR (INCLUDING + *NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, + *EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/** + * \file dnn/src/cuda/cutlass/convolution_operation.h + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#pragma once + +#include "cutlass/convolution/device/convolution.h" +#include "src/cuda/cutlass/library_internal.h" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class ConvolutionOperationBase : public Operation { +public: + using Operator = Operator_; + using ElementSrc = typename Operator::ElementSrc; + using LayoutSrc = typename Operator::LayoutSrc; + using ElementFilter = typename Operator::ElementFilter; + using LayoutFilter = typename Operator::LayoutFilter; + using ElementDst = typename Operator::ElementDst; + using LayoutDst = typename Operator::LayoutDst; + using ElementBias = typename Operator::ElementBias; + using LayoutBias = typename Operator::LayoutBias; + using ElementAccumulator = typename Operator::ElementAccumulator; + + ConvolutionOperationBase(char const* name = "unknown_convolution") { + m_description.name = name; + m_description.provider = Provider::kCUTLASS; + m_description.kind = OperationKind::kConvolution; + m_description.conv_op = Operator::kConvolutionalOperator; + + m_description.tile_description.threadblock_shape = make_Coord( + Operator::ThreadblockShape::kM, Operator::ThreadblockShape::kN, + Operator::ThreadblockShape::kK); + + m_description.tile_description.threadblock_stages = Operator::kStages; + + m_description.tile_description.warp_count = + make_Coord(Operator::ConvolutionKernel::WarpCount::kM, + Operator::ConvolutionKernel::WarpCount::kN, + Operator::ConvolutionKernel::WarpCount::kK); + + m_description.tile_description.math_instruction.instruction_shape = + make_Coord(Operator::InstructionShape::kM, + Operator::InstructionShape::kN, + Operator::InstructionShape::kK); + + m_description.tile_description.math_instruction.element_accumulator = + NumericTypeMap::kId; + + m_description.tile_description.math_instruction.opcode_class = + OpcodeClassMap::kId; + + m_description.tile_description.math_instruction.math_operation = + MathOperationMap::kId; + + m_description.tile_description.minimum_compute_capability = + ArchMap::kMin; + + m_description.tile_description.maximum_compute_capability = + ArchMap::kMax; + + m_description.src = make_TensorDescription( + Operator::kAlignmentSrc); + m_description.filter = + make_TensorDescription( + Operator::kAlignmentFilter); + m_description.dst = make_TensorDescription( + Operator::kAlignmentDst); + m_description.bias = make_TensorDescription( + 
Operator::kAlignmentDst);
+
+        m_description.convolution_type = Operator::kConvolutionType;
+        m_description.arch_tag = ArchTagMap<typename Operator::ArchTag>::kId;
+
+        m_description.epilogue_type = Operator::EpilogueOutputOp::kType;
+        m_description.epilogue_count = Operator::EpilogueOutputOp::kCount;
+
+        m_description.threadblock_swizzle = ThreadblockSwizzleMap<
+                typename Operator::ThreadblockSwizzle>::kId;
+
+        m_description.need_load_from_const_mem =
+                Operator::kNeedLoadFromConstMem;
+        m_description.gemm_mode = Operator::kGemmMode;
+        m_description.without_shared_load = Operator::kWithoutSharedLoad;
+    }
+
+    virtual OperationDescription const& description() const {
+        return m_description;
+    }
+
+protected:
+    ConvolutionDescription m_description;
+};
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace detail {
+
+template <typename EpilogueOp, epilogue::EpilogueType type>
+struct init_epilogue_param_;
+
+template <typename EpilogueOp>
+struct init_epilogue_param_<
+        EpilogueOp, epilogue::EpilogueType::kBiasAddLinearCombination> {
+    using ElementCompute = typename EpilogueOp::ElementCompute;
+    typename EpilogueOp::Params get(ConvolutionArguments const* conv_args) {
+        return {*static_cast<ElementCompute const*>(conv_args->alpha),
+                *static_cast<ElementCompute const*>(conv_args->beta),
+                *static_cast<ElementCompute const*>(conv_args->gamma),
+                *static_cast<ElementCompute const*>(conv_args->delta)};
+    }
+};
+
+template <typename EpilogueOp>
+struct init_epilogue_param_<
+        EpilogueOp, epilogue::EpilogueType::kBiasAddLinearCombinationClamp> {
+    using ElementCompute = typename EpilogueOp::ElementCompute;
+    typename EpilogueOp::Params get(ConvolutionArguments const* conv_args) {
+        return {*static_cast<ElementCompute const*>(conv_args->alpha),
+                *static_cast<ElementCompute const*>(conv_args->beta),
+                *static_cast<ElementCompute const*>(conv_args->gamma),
+                *static_cast<ElementCompute const*>(conv_args->delta)};
+    }
+};
+
+template <typename EpilogueOp>
+struct init_epilogue_param_<
+        EpilogueOp, epilogue::EpilogueType::kBiasAddLinearCombinationRelu> {
+    using ElementCompute = typename EpilogueOp::ElementCompute;
+    typename EpilogueOp::Params get(ConvolutionArguments const* conv_args) {
+        return {*static_cast<ElementCompute const*>(conv_args->alpha),
+                *static_cast<ElementCompute const*>(conv_args->beta),
+                *static_cast<ElementCompute const*>(conv_args->gamma),
+                *static_cast<ElementCompute const*>(conv_args->threshold),
+                *static_cast<ElementCompute const*>(conv_args->delta),
+                *static_cast<ElementCompute const*>(conv_args->theta)};
+    }
+};
+
+template <typename EpilogueOp>
+struct init_epilogue_param_<
+        EpilogueOp,
+        epilogue::EpilogueType::kBiasAddLinearCombinationReluClamp> {
+    using ElementCompute = typename EpilogueOp::ElementCompute;
+    typename EpilogueOp::Params get(ConvolutionArguments const* conv_args) {
+        return {*static_cast<ElementCompute const*>(conv_args->alpha),
+                *static_cast<ElementCompute const*>(conv_args->beta),
+                *static_cast<ElementCompute const*>(conv_args->gamma),
+                *static_cast<ElementCompute const*>(conv_args->threshold),
+                *static_cast<ElementCompute const*>(conv_args->delta),
+                *static_cast<ElementCompute const*>(conv_args->theta)};
+    }
+};
+
+template <typename EpilogueOp>
+struct init_epilogue_param_<
+        EpilogueOp, epilogue::EpilogueType::kBiasAddLinearCombinationHSwish> {
+    using ElementCompute = typename EpilogueOp::ElementCompute;
+    typename EpilogueOp::Params get(ConvolutionArguments const* conv_args) {
+        return {*static_cast<ElementCompute const*>(conv_args->alpha),
+                *static_cast<ElementCompute const*>(conv_args->beta),
+                *static_cast<ElementCompute const*>(conv_args->gamma),
+                *static_cast<ElementCompute const*>(conv_args->scale),
+                *static_cast<ElementCompute const*>(conv_args->delta),
+                *static_cast<ElementCompute const*>(conv_args->theta)};
+    }
+};
+
+template <typename EpilogueOp>
+struct init_epilogue_param_<
+        EpilogueOp,
+        epilogue::EpilogueType::kBiasAddLinearCombinationHSwishClamp> {
+    using ElementCompute = typename EpilogueOp::ElementCompute;
+    typename EpilogueOp::Params get(ConvolutionArguments const* conv_args) {
+        return {*static_cast<ElementCompute const*>(conv_args->alpha),
+                *static_cast<ElementCompute const*>(conv_args->beta),
+                *static_cast<ElementCompute const*>(conv_args->gamma),
+                *static_cast<ElementCompute const*>(conv_args->scale),
+                *static_cast<ElementCompute const*>(conv_args->delta),
+                *static_cast<ElementCompute const*>(conv_args->theta)};
+    }
+};
+
+} //
namespace detail + +template +struct init_epilogue_param + : public detail::init_epilogue_param_ {}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class ConvolutionOperation : public ConvolutionOperationBase { +public: + using Operator = Operator_; + using ElementSrc = typename Operator::ElementSrc; + using LayoutSrc = typename Operator::LayoutSrc; + using ElementFilter = typename Operator::ElementFilter; + using LayoutFilter = typename Operator::LayoutFilter; + using ElementBias = typename Operator::ElementBias; + using LayoutBias = typename Operator::LayoutBias; + using ElementDst = typename Operator::ElementDst; + using LayoutDst = typename Operator::LayoutDst; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + + using OperatorArguments = typename Operator::Arguments; + + ConvolutionOperation(char const* name = "unknown_gemm") + : ConvolutionOperationBase(name) {} + + virtual Status run(void const* arguments_ptr, + void* device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + cutlass::conv::Operator conv_op = this->m_description.conv_op; + ConvolutionArguments const* conv_args = + reinterpret_cast(arguments_ptr); + const auto& ps = conv_args->problem_size; + + OperatorArguments args; + args.problem_size = ps; + args.ref_src = { + static_cast(const_cast(conv_args->src)), + LayoutSrc::packed(implicit_gemm_tensor_a_extent(conv_op, ps))}; + args.ref_filter = {static_cast( + const_cast(conv_args->filter)), + LayoutFilter::packed( + implicit_gemm_tensor_b_extent(conv_op, ps))}; + args.ref_bias = { + static_cast(const_cast(conv_args->bias)), + LayoutBias::packed( + implicit_gemm_tensor_bias_extent(conv_op, ps))}; + args.ref_z = { + static_cast(const_cast(conv_args->z)), + LayoutDst::packed(implicit_gemm_tensor_c_extent(conv_op, ps))}; + args.ref_dst = { + static_cast(conv_args->dst), + LayoutDst::packed(implicit_gemm_tensor_c_extent(conv_op, ps))}; + + args.output_op = + init_epilogue_param().get( + conv_args); + + if (conv_args->extra_param) { + args.extra_param = + *reinterpret_cast( + conv_args->extra_param); + } + + Operator op; + Status status = op.initialize(args, device_workspace); + + if (status != Status::kSuccess) { + return status; + } + + return op.run(stream); + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/dnn/src/cuda/cutlass/gemm_operation.h b/dnn/src/cuda/cutlass/gemm_operation.h new file mode 100644 index 00000000..20c5b9a0 --- /dev/null +++ b/dnn/src/cuda/cutlass/gemm_operation.h @@ -0,0 +1,202 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + *modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + *notice, this list of conditions and the following disclaimer in the + *documentation and/or other materials provided with the distribution. 
+/***************************************************************************************************
+ * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *     * Redistributions of source code must retain the above copyright notice,
+ *       this list of conditions and the following disclaimer.
+ *     * Redistributions in binary form must reproduce the above copyright
+ *       notice, this list of conditions and the following disclaimer in the
+ *       documentation and/or other materials provided with the distribution.
+ *     * Neither the name of the NVIDIA CORPORATION nor the names of its
+ *       contributors may be used to endorse or promote products derived from
+ *       this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ * ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
+ * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+ * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+ * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
+ * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
+ * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/**
+ * \file dnn/src/cuda/cutlass/gemm_operation.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+ * or implied.
+ */
+
+#pragma once
+
+#include "cutlass/gemm/device/gemm.h"
+#include "src/cuda/cutlass/library_internal.h"
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass {
+namespace library {
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Check whether Operator has a member ReductionKernel using SFINAE
+/// (Substitution Failure Is Not An Error)
+template <typename Operator>
+struct split_k_mode {
+    template <typename T>
+    static char check(typename T::ReductionKernel*);
+
+    template <typename T>
+    static int check(...);
+
+    SplitKMode operator()() {
+        if (sizeof(check<Operator>(0)) == sizeof(char)) {
+            // cutlass::gemm::device::GemmSplitKParallel
+            return SplitKMode::kParallel;
+        } else {
+            // cutlass::gemm::device::Gemm
+            return SplitKMode::kNone;
+        }
+    }
+};
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <typename Operator_>
+class GemmOperationBase : public Operation {
+public:
+    using Operator = Operator_;
+    using ElementA = typename Operator::ElementA;
+    using LayoutA = typename Operator::LayoutA;
+    using ElementB = typename Operator::ElementB;
+    using LayoutB = typename Operator::LayoutB;
+    using ElementC = typename Operator::ElementC;
+    using LayoutC = typename Operator::LayoutC;
+    using ElementAccumulator = typename Operator::ElementAccumulator;
+
+    GemmOperationBase(char const* name = "unknown_gemm") {
+        m_description.name = name;
+        m_description.provider = Provider::kCUTLASS;
+        m_description.kind = OperationKind::kGemm;
+        m_description.gemm_kind = GemmKind::kGemm;
+
+        m_description.tile_description.threadblock_shape = make_Coord(
+                Operator::ThreadblockShape::kM, Operator::ThreadblockShape::kN,
+                Operator::ThreadblockShape::kK);
+
+        m_description.tile_description.threadblock_stages = Operator::kStages;
+
+        m_description.tile_description.warp_count =
+                make_Coord(Operator::GemmKernel::WarpCount::kM,
+                           Operator::GemmKernel::WarpCount::kN,
+                           Operator::GemmKernel::WarpCount::kK);
+
+        m_description.tile_description.math_instruction.instruction_shape =
+                make_Coord(Operator::InstructionShape::kM,
+                           Operator::InstructionShape::kN,
+                           Operator::InstructionShape::kK);
+
+        m_description.tile_description.math_instruction.element_accumulator =
+                NumericTypeMap<ElementAccumulator>::kId;
+
+        m_description.tile_description.math_instruction.opcode_class =
+                OpcodeClassMap<typename Operator::OperatorClass>::kId;
+
+        m_description.tile_description.math_instruction.math_operation =
+                MathOperationMap<typename Operator::Operator>::kId;
+
+        m_description.tile_description.minimum_compute_capability =
+                ArchMap<typename Operator::ArchTag,
+                        typename Operator::OperatorClass>::kMin;
+
+        m_description.tile_description.maximum_compute_capability =
+                ArchMap<typename Operator::ArchTag,
+                        typename Operator::OperatorClass>::kMax;
+
+        m_description.A = make_TensorDescription<ElementA, LayoutA>(
+                Operator::kAlignmentA);
+        m_description.B = make_TensorDescription<ElementB, LayoutB>(
+                Operator::kAlignmentB);
+        m_description.C = make_TensorDescription<ElementC, LayoutC>(
+                Operator::kAlignmentC);
+
+        m_description.stages = Operator::kStages;
+
+        split_k_mode<Operator> mode;
+        m_description.split_k_mode = mode();
+    }
+
+    virtual OperationDescription const& description() const {
+        return m_description;
+    }
+
+protected:
+    GemmDescription m_description;
+};
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <typename Operator_>
+class GemmOperation : public GemmOperationBase<Operator_> {
+public:
+    using Operator = Operator_;
+    using ElementA = typename Operator::ElementA;
+    using LayoutA = typename Operator::LayoutA;
+    using ElementB = typename Operator::ElementB;
+    using LayoutB = typename Operator::LayoutB;
+    using ElementC = typename Operator::ElementC;
+    using LayoutC = typename Operator::LayoutC;
+    using ElementAccumulator = typename Operator::ElementAccumulator;
+    using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute;
+
+    using OperatorArguments = typename Operator::Arguments;
+
+    GemmOperation(char const* name = "unknown_gemm")
+            : GemmOperationBase<Operator_>(name) {}
+
+    virtual Status run(void const* arguments_ptr,
+                       void* device_workspace = nullptr,
+                       cudaStream_t stream = nullptr) const {
+        GemmArguments const* gemm_args =
+                reinterpret_cast<GemmArguments const*>(arguments_ptr);
+
+        OperatorArguments args;
+        args.problem_size = gemm_args->problem_size;
+        args.ref_A = {static_cast<ElementA const*>(gemm_args->A),
+                      int(gemm_args->lda)};
+        args.ref_B = {static_cast<ElementB const*>(gemm_args->B),
+                      int(gemm_args->ldb)};
+        args.ref_C = {static_cast<ElementC const*>(gemm_args->C),
+                      int(gemm_args->ldc)};
+        args.ref_D = {static_cast<ElementC*>(gemm_args->D),
+                      int(gemm_args->ldd)};
+        args.split_k_slices = gemm_args->split_k_slices;
+
+        args.epilogue = {*static_cast<ElementCompute const*>(gemm_args->alpha),
+                         *static_cast<ElementCompute const*>(gemm_args->beta)};
+
+        Operator op;
+        Status status = op.initialize(args, device_workspace);
+
+        if (status != Status::kSuccess) {
+            return status;
+        }
+
+        return op.run(stream);
+    }
+};
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
+
+}  // namespace library
+}  // namespace cutlass
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
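+
+// Usage sketch: the generated .cu files (see gemm_operation.py and manifest.py
+// in this patch) instantiate a concrete device-level GEMM and append it to the
+// manifest. The template parameters and the kernel name below are illustrative
+// only:
+//
+//     using Gemm_f32 = cutlass::gemm::device::Gemm<
+//             float, cutlass::layout::RowMajor,   // A
+//             float, cutlass::layout::RowMajor,   // B
+//             float, cutlass::layout::RowMajor>;  // C
+//
+//     void initialize_example(Manifest& manifest) {
+//         manifest.append(
+//                 new GemmOperation<Gemm_f32>("cutlass_simt_sgemm_example"));
+//     }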
diff --git a/dnn/src/cuda/cutlass/initialize_all.cu b/dnn/src/cuda/cutlass/initialize_all.cu
new file mode 100644
index 00000000..1e67822d
--- /dev/null
+++ b/dnn/src/cuda/cutlass/initialize_all.cu
@@ -0,0 +1,76 @@
+/***************************************************************************************************
+ * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *     * Redistributions of source code must retain the above copyright notice,
+ *       this list of conditions and the following disclaimer.
+ *     * Redistributions in binary form must reproduce the above copyright
+ *       notice, this list of conditions and the following disclaimer in the
+ *       documentation and/or other materials provided with the distribution.
+ *     * Neither the name of the NVIDIA CORPORATION nor the names of its
+ *       contributors may be used to endorse or promote products derived from
+ *       this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ * ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
+ * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+ * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+ * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
+ * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
+ * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/**
+ * \file dnn/src/cuda/cutlass/initialize_all.cu
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */ + +#include "src/cuda/cutlass/manifest.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#if __CUDACC_VER_MAJOR__ > 9 || \ + (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2) + +void initialize_all_gemm_simt_operations(Manifest& manifest); +void initialize_all_conv2d_simt_operations(Manifest& manifest); +void initialize_all_conv2d_tensorop8816_operations(Manifest& manifest); +void initialize_all_conv2d_tensorop8832_operations(Manifest& manifest); +void initialize_all_deconv_simt_operations(Manifest& manifest); + +void initialize_all(Manifest& manifest) { + initialize_all_gemm_simt_operations(manifest); + initialize_all_conv2d_simt_operations(manifest); + initialize_all_conv2d_tensorop8816_operations(manifest); + initialize_all_conv2d_tensorop8832_operations(manifest); + initialize_all_deconv_simt_operations(manifest); +} + +#else + +void initialize_all(Manifest& manifest) {} + +#endif + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/dnn/src/cuda/cutlass/library.h b/dnn/src/cuda/cutlass/library.h new file mode 100644 index 00000000..9907fad1 --- /dev/null +++ b/dnn/src/cuda/cutlass/library.h @@ -0,0 +1,541 @@ +/*************************************************************************************************** + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + *modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + *notice, this list of conditions and the following disclaimer in the + *documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its + *contributors may be used to endorse or promote products derived from this + *software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + *DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, + *INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + *DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY + *OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TOR (INCLUDING + *NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, + *EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/** + * \file dnn/src/cuda/cutlass/library.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. 
+ * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#pragma once + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include +#include +#include +#include +#include + +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wreorder" +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#pragma GCC diagnostic ignored "-Wunused-parameter" + +#include "cutlass/cutlass.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/tensor_coord.h" + +#include "cutlass/conv/conv2d_problem_size.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/epilogue/epilogue.h" +#include "cutlass/gemm/gemm.h" + +#pragma GCC diagnostic pop + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Layout type identifier +enum class LayoutTypeID { + kUnknown, + kColumnMajor, + kRowMajor, + kColumnMajorInterleavedK2, + kRowMajorInterleavedK2, + kColumnMajorInterleavedK4, + kRowMajorInterleavedK4, + kColumnMajorInterleavedK16, + kRowMajorInterleavedK16, + kColumnMajorInterleavedK32, + kRowMajorInterleavedK32, + kColumnMajorInterleavedK64, + kRowMajorInterleavedK64, + kTensorNCHW, + kTensorNCDHW, + kTensorNHWC, + kTensorNDHWC, + kTensorNC4HW4, + kTensorC4RSK4, + kTensorNC8HW8, + kTensorC8RSK8, + kTensorNC16HW16, + kTensorC16RSK16, + kTensorNC32HW32, + kTensorC32RSK32, + kTensorNC64HW64, + kTensorC64RSK64, + kTensorK4RSC4, + kInvalid +}; + +/// Numeric data type +enum class NumericTypeID { + kUnknown, + kVoid, + kB1, + kU2, + kU4, + kU8, + kU16, + kU32, + kU64, + kS2, + kS4, + kS8, + kS16, + kS32, + kS64, + kF16, + kBF16, + kTF32, + kF32, + kF64, + kCF16, + kCBF16, + kCF32, + kCTF32, + kCF64, + kCS2, + kCS4, + kCS8, + kCS16, + kCS32, + kCS64, + kCU2, + kCU4, + kCU8, + kCU16, + kCU32, + kCU64, + kInvalid +}; + +/// Enumerated type describing a transformation on a complex value. 
+enum class ComplexTransform { kNone, kConjugate, kInvalid };
+
+/// Providers
+enum class Provider {
+    kNone,
+    kCUTLASS,
+    kReferenceHost,
+    kReferenceDevice,
+    kCUBLAS,
+    kCUDNN,
+    kInvalid
+};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Enumeration indicating the kind of operation
+enum class OperationKind {
+    kGemm,
+    kConv2d,
+    kConv3d,
+    kConvolution,
+    kEqGemm,
+    kSparseGemm,
+    kReduction,
+    kInvalid
+};
+
+/// Enumeration indicating whether scalars are in host or device memory
+enum class ScalarPointerMode { kHost, kDevice, kInvalid };
+
+/// Describes how reductions are performed across threadblocks
+enum class SplitKMode { kNone, kSerial, kParallel, kParallelSerial, kInvalid };
+
+/// Indicates the classification of the math instruction
+enum class OpcodeClassID {
+    kSimt,
+    kTensorOp,
+    kWmmaTensorOp,
+    kSparseTensorOp,
+    kInvalid
+};
+
+enum class ArchTagID {
+    kSm50,
+    kSm60,
+    kSm61,
+    kSm70,
+    kSm72,
+    kSm75,
+    kSm80,
+    kSm86,
+    kInvalid
+};
+
+enum class MathOperationID {
+    kAdd,
+    kMultiplyAdd,
+    kMultiplyAddSaturate,
+    kMultiplyAddFastBF16,
+    kMultiplyAddFastF16,
+    kMultiplyAddComplex,
+    kMultiplyAddGaussianComplex,
+    kXorPopc,
+    kInvalid
+};
+
+enum class ThreadblockSwizzleID {
+    kGemmIdentity,
+    kGemmHorizontal,
+    kGemmBatchedIdentity,
+    kGemmSplitKIdentity,
+    kGemmSplitKHorizontal,
+    kGemvBatchedStridedDefault,
+    kGemvBatchedStridedReduction,
+    kConvolutionFpropCxRSKx,
+    kConvolutionDgradCxRSKx,
+    kConvolutionFpropNCxHWx,
+    kConvolutionFpropTrans,
+    kConvolutionDgradNCxHWx,
+    kInvalid
+};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Enumeration indicating what kind of GEMM operation to perform
+enum class GemmKind {
+    kGemm,
+    kSparse,
+    kUniversal,
+    kPlanarComplex,
+    kPlanarComplexArray,
+    kInvalid
+};
+
+/// Mode of Universal GEMM
+using GemmUniversalMode = cutlass::gemm::GemmUniversalMode;
+
+/// Enumeration indicating what kind of Conv2d operation to perform
+enum class ConvKind { kUnknown, kFprop, kDgrad, kWgrad, kInvalid };
+
+enum class ConvModeID { kCrossCorrelation, kConvolution, kInvalid };
+
+// Iterator algorithm enum in order of general performance-efficiency
+enum class IteratorAlgorithmID { kNone, kAnalytic, kOptimized, kInvalid };
+
+enum class EpilogueKind {
+    kUnknown,
+    kBiasAddLinearCombination,
+    kBiasAddLinearCombinationClamp,
+    kBiasAddLinearCombinationHSwish,
+    kBiasAddLinearCombinationHSwishClamp,
+    kBiasAddLinearCombinationRelu,
+    kBiasAddLinearCombinationReluClamp,
+    kConversion,
+    kLinearCombination,
+    kLinearCombinationClamp,
+    kLinearCombinationPlanarComplex,
+    kLinearCombinationRelu,
+    kLinearCombinationSigmoid,
+    kInvalid
+};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+struct MathInstructionDescription {
+    /// Shape of the target math instruction
+    cutlass::gemm::GemmCoord instruction_shape;
+
+    /// Describes the data type of the internal accumulator
+    NumericTypeID element_accumulator;
+
+    /// Classification of the math instruction
+    OpcodeClassID opcode_class;
+
+    /// Type of math operation performed
+    MathOperationID math_operation;
+
+    //
+    // Methods
+    //
+
+    MathInstructionDescription(
+            cutlass::gemm::GemmCoord instruction_shape =
+                    cutlass::gemm::GemmCoord(),
+            NumericTypeID element_accumulator = NumericTypeID::kInvalid,
+            OpcodeClassID opcode_class = OpcodeClassID::kInvalid,
+            MathOperationID math_operation = MathOperationID::kMultiplyAdd)
+            : instruction_shape(instruction_shape),
+              element_accumulator(element_accumulator),
+              opcode_class(opcode_class),
+              math_operation(math_operation) {}
+
+    // Equality operator
+    inline bool operator==(MathInstructionDescription const& rhs) const {
+        return ((instruction_shape == rhs.instruction_shape) &&
+                (element_accumulator == rhs.element_accumulator) &&
+                (opcode_class == rhs.opcode_class) &&
+                (math_operation == rhs.math_operation));
+    }
+
+    // Inequality operator
+    inline bool operator!=(MathInstructionDescription const& rhs) const {
+        return !(*this == rhs);
+    }
+};
+
+/// Structure describing the tiled structure of a GEMM-like computation
+struct TileDescription {
+    /// Describes the shape of a threadblock (in elements)
+    cutlass::gemm::GemmCoord threadblock_shape;
+
+    /// Describes the number of pipeline stages in the threadblock-scoped
+    /// mainloop
+    int threadblock_stages;
+
+    /// Number of warps in each logical dimension
+    cutlass::gemm::GemmCoord warp_count;
+
+    /// Core math instruction
+    MathInstructionDescription math_instruction;
+
+    /// Minimum compute capability (e.g. 70, 75) of a device eligible to run
+    /// the operation.
+    int minimum_compute_capability;
+
+    /// Maximum compute capability (e.g. 70, 75) of a device eligible to run
+    /// the operation.
+    int maximum_compute_capability;
+
+    //
+    // Methods
+    //
+
+    TileDescription(
+            cutlass::gemm::GemmCoord threadblock_shape =
+                    cutlass::gemm::GemmCoord(),
+            int threadblock_stages = 0,
+            cutlass::gemm::GemmCoord warp_count = cutlass::gemm::GemmCoord(),
+            MathInstructionDescription math_instruction =
+                    MathInstructionDescription(),
+            int minimum_compute_capability = 0,
+            int maximum_compute_capability = 0)
+            : threadblock_shape(threadblock_shape),
+              threadblock_stages(threadblock_stages),
+              warp_count(warp_count),
+              math_instruction(math_instruction),
+              minimum_compute_capability(minimum_compute_capability),
+              maximum_compute_capability(maximum_compute_capability) {}
+
+    // Equality operator
+    inline bool operator==(TileDescription const& rhs) const {
+        return ((threadblock_shape == rhs.threadblock_shape) &&
+                (threadblock_stages == rhs.threadblock_stages) &&
+                (warp_count == rhs.warp_count) &&
+                (math_instruction == rhs.math_instruction) &&
+                (minimum_compute_capability ==
+                 rhs.minimum_compute_capability) &&
+                (maximum_compute_capability ==
+                 rhs.maximum_compute_capability));
+    }
+
+    // Inequality operator
+    inline bool operator!=(TileDescription const& rhs) const {
+        return !(*this == rhs);
+    }
+};
+
+/// High-level description of an operation
+struct OperationDescription {
+    /// Unique identifier describing the operation
+    char const* name;
+
+    /// Operation provider
+    Provider provider;
+
+    /// Kind of operation
+    OperationKind kind;
+
+    /// Describes the tiled structure of a GEMM-like computation
+    TileDescription tile_description;
+
+    //
+    // Methods
+    //
+    OperationDescription(
+            char const* name = "unknown",
+            OperationKind kind = OperationKind::kInvalid,
+            TileDescription const& tile_description = TileDescription())
+            : name(name), kind(kind), tile_description(tile_description) {}
+};
+
+/// Structure describing the properties of a tensor
+struct TensorDescription {
+    /// Numeric type of an individual element
+    NumericTypeID element;
+
+    /// Enumerant identifying the layout function for the tensor
+    LayoutTypeID layout;
+
+    /// Alignment restriction on pointers, strides, and extents
+    int alignment;
+
+    /// log2() of the maximum extent of each dimension
+    int log_extent_range;
+
+    /// log2() of the maximum value each relevant stride may have
+    int log_stride_range;
+
+    //
+    // Methods
+    //
+
+    TensorDescription(NumericTypeID element = NumericTypeID::kInvalid,
+                      LayoutTypeID layout = LayoutTypeID::kInvalid,
+                      int alignment = 1, int log_extent_range = 24,
+                      int log_stride_range = 24)
+            : element(element),
+              layout(layout),
+              alignment(alignment),
+              log_extent_range(log_extent_range),
+              log_stride_range(log_stride_range) {}
+};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+struct GemmDescription : public OperationDescription {
+    GemmKind gemm_kind;
+
+    TensorDescription A;
+    TensorDescription B;
+    TensorDescription C;
+
+    int stages;
+    SplitKMode split_k_mode;
+};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+struct GemmArguments {
+    /// GEMM problem size
+    gemm::GemmCoord problem_size;
+
+    /// Device pointers to input and output matrices
+    void const* A;
+    void const* B;
+    void const* C;
+    void* D;
+
+    /// Leading dimensions of input and output matrices
+    int64_t lda;
+    int64_t ldb;
+    int64_t ldc;
+    int64_t ldd;
+
+    /// Number of partitions of the K dimension
+    int split_k_slices;
+
+    /// Host or device pointers to the epilogue scalars. Note that these
+    /// pointers are reinterpreted as ElementCompute* inside `op->run(args)`;
+    /// passing scalars of a different dtype results in undefined epilogue
+    /// behavior.
+    void const* alpha;
+    void const* beta;
+};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+struct ConvolutionDescription : public OperationDescription {
+    conv::Operator conv_op;
+
+    TensorDescription src;
+    TensorDescription filter;
+    TensorDescription dst;
+    TensorDescription bias;
+
+    conv::ConvType convolution_type;
+    ArchTagID arch_tag;
+
+    epilogue::EpilogueType epilogue_type;
+    int epilogue_count;
+
+    ThreadblockSwizzleID threadblock_swizzle;
+
+    bool need_load_from_const_mem;
+    conv::ImplicitGemmMode gemm_mode;
+    bool without_shared_load;
+};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+struct ConvolutionArguments {
+    /// Problem size
+    conv::Conv2dProblemSize problem_size;
+
+    /// Device pointers to input and output tensors
+    void const* src;
+    void const* filter;
+    void const* bias;
+    void const* z;
+    void* dst;
+
+    /// Host or device pointers to the epilogue scalars. Note that these
+    /// pointers are reinterpreted as ElementCompute* inside `op->run(args)`;
+    /// passing scalars of a different dtype results in undefined epilogue
+    /// behavior.
+    void const* alpha;
+    void const* beta;
+    void const* gamma;
+    void const* delta;
+    void const* theta;
+    void const* threshold;
+    void const* scale;
+
+    /// Host pointer to the extra param struct
+    void const* extra_param;
+};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Base class for all operations
+class Operation {
+public:
+    virtual ~Operation() {}
+
+    virtual OperationDescription const& description() const = 0;
+
+    virtual Status run(void const* arguments, void* device_workspace = nullptr,
+                       cudaStream_t stream = nullptr) const = 0;
+};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+}  // namespace library
+}  // namespace cutlass
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
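+
+// Usage sketch: a caller fills GemmArguments with device pointers plus host
+// pointers to the epilogue scalars, then dispatches through the type-erased
+// Operation interface. `op`, the device pointers, the problem shape and the
+// strides below are assumed to come from the surrounding algorithm:
+//
+//     float alpha = 1.f, beta = 0.f;  // dtype must match ElementCompute
+//     GemmArguments args;
+//     args.problem_size = {m, n, k};
+//     args.A = d_a;  args.lda = lda;
+//     args.B = d_b;  args.ldb = ldb;
+//     args.C = d_c;  args.ldc = ldc;
+//     args.D = d_d;  args.ldd = ldd;
+//     args.split_k_slices = 1;
+//     args.alpha = &alpha;
+//     args.beta = &beta;
+//     op->run(&args, workspace_ptr, stream);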
diff --git a/dnn/src/cuda/cutlass/library_internal.h b/dnn/src/cuda/cutlass/library_internal.h
new file mode 100644
index 00000000..a6bfb01c
--- /dev/null
+++ b/dnn/src/cuda/cutlass/library_internal.h
@@ -0,0 +1,580 @@
+/***************************************************************************************************
+ * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *     * Redistributions of source code must retain the above copyright notice,
+ *       this list of conditions and the following disclaimer.
+ *     * Redistributions in binary form must reproduce the above copyright
+ *       notice, this list of conditions and the following disclaimer in the
+ *       documentation and/or other materials provided with the distribution.
+ *     * Neither the name of the NVIDIA CORPORATION nor the names of its
+ *       contributors may be used to endorse or promote products derived from
+ *       this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ * ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
+ * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+ * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+ * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
+ * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
+ * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/**
+ * \file dnn/src/cuda/cutlass/library_internal.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */ + +#pragma once + +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wreorder" +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#pragma GCC diagnostic ignored "-Wunused-parameter" + +#include "cutlass/arch/arch.h" +#include "cutlass/arch/mma.h" +#include "cutlass/complex.h" +#include "cutlass/convolution/threadblock/threadblock_swizzle.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" + +#pragma GCC diagnostic pop + +#include "src/cuda/cutlass/arch_mappings.h" +#include "src/cuda/cutlass/library.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct NumericTypeMap; + +template <> +struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kB1; +}; + +template <> +struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kS4; +}; + +template <> +struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kS8; +}; + +template <> +struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kS16; +}; + +template <> +struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kS32; +}; + +template <> +struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kS64; +}; + +template <> +struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kU4; +}; + +template <> +struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kU8; +}; + +template <> +struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kU16; +}; + +template <> +struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kU32; +}; + +template <> +struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kU64; +}; + +template <> +struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kF16; +}; + +template <> +struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kF32; +}; + +template <> +struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kF64; +}; + +template <> +struct NumericTypeMap> { + static NumericTypeID const kId = NumericTypeID::kCF16; +}; + +template <> +struct NumericTypeMap> { + static NumericTypeID const kId = NumericTypeID::kCF32; +}; + +template <> +struct NumericTypeMap> { + static NumericTypeID const kId = NumericTypeID::kCF64; +}; + +template <> +struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kBF16; +}; + +template <> +struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kTF32; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MathOperationMap { + static MathOperationID const kId = MathOperationID::kInvalid; +}; + +template <> +struct MathOperationMap { + static MathOperationID const kId = MathOperationID::kMultiplyAdd; +}; + +template <> +struct MathOperationMap { + static MathOperationID const kId = MathOperationID::kMultiplyAddFastBF16; +}; + +template <> +struct MathOperationMap { + static MathOperationID const kId = MathOperationID::kMultiplyAddFastF16; +}; + +template <> +struct MathOperationMap { + static MathOperationID const kId = MathOperationID::kMultiplyAddSaturate; +}; + +template <> +struct MathOperationMap { + 
static MathOperationID const kId = MathOperationID::kMultiplyAddComplex; +}; + +template <> +struct MathOperationMap { + static MathOperationID const kId = + MathOperationID::kMultiplyAddGaussianComplex; +}; + +template <> +struct MathOperationMap { + static MathOperationID const kId = MathOperationID::kXorPopc; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct LayoutMap; + +template <> +struct LayoutMap { + static LayoutTypeID const kId = LayoutTypeID::kColumnMajor; +}; + +template <> +struct LayoutMap { + static LayoutTypeID const kId = LayoutTypeID::kRowMajor; +}; + +template <> +struct LayoutMap> { + static LayoutTypeID const kId = LayoutTypeID::kColumnMajorInterleavedK2; +}; + +template <> +struct LayoutMap> { + static LayoutTypeID const kId = LayoutTypeID::kRowMajorInterleavedK2; +}; + +template <> +struct LayoutMap> { + static LayoutTypeID const kId = LayoutTypeID::kColumnMajorInterleavedK4; +}; + +template <> +struct LayoutMap> { + static LayoutTypeID const kId = LayoutTypeID::kRowMajorInterleavedK4; +}; + +template <> +struct LayoutMap> { + static LayoutTypeID const kId = LayoutTypeID::kColumnMajorInterleavedK16; +}; + +template <> +struct LayoutMap> { + static LayoutTypeID const kId = LayoutTypeID::kRowMajorInterleavedK16; +}; + +template <> +struct LayoutMap> { + static LayoutTypeID const kId = LayoutTypeID::kColumnMajorInterleavedK32; +}; + +template <> +struct LayoutMap> { + static LayoutTypeID const kId = LayoutTypeID::kRowMajorInterleavedK32; +}; + +template <> +struct LayoutMap> { + static LayoutTypeID const kId = LayoutTypeID::kColumnMajorInterleavedK64; +}; + +template <> +struct LayoutMap> { + static LayoutTypeID const kId = LayoutTypeID::kRowMajorInterleavedK64; +}; + +template <> +struct LayoutMap { + static LayoutTypeID const kId = LayoutTypeID::kTensorNCHW; +}; + +template <> +struct LayoutMap { + static LayoutTypeID const kId = LayoutTypeID::kTensorNHWC; +}; + +template <> +struct LayoutMap { + static LayoutTypeID const kId = LayoutTypeID::kTensorNDHWC; +}; + +template <> +struct LayoutMap> { + static LayoutTypeID const kId = LayoutTypeID::kTensorNC4HW4; +}; + +template <> +struct LayoutMap> { + static LayoutTypeID const kId = LayoutTypeID::kTensorNC8HW8; +}; + +template <> +struct LayoutMap> { + static LayoutTypeID const kId = LayoutTypeID::kTensorNC16HW16; +}; + +template <> +struct LayoutMap> { + static LayoutTypeID const kId = LayoutTypeID::kTensorNC32HW32; +}; + +template <> +struct LayoutMap> { + static LayoutTypeID const kId = LayoutTypeID::kTensorNC64HW64; +}; + +template <> +struct LayoutMap> { + static LayoutTypeID const kId = LayoutTypeID::kTensorC4RSK4; +}; + +template <> +struct LayoutMap> { + static LayoutTypeID const kId = LayoutTypeID::kTensorC8RSK8; +}; + +template <> +struct LayoutMap> { + static LayoutTypeID const kId = LayoutTypeID::kTensorC16RSK16; +}; + +template <> +struct LayoutMap> { + static LayoutTypeID const kId = LayoutTypeID::kTensorC32RSK32; +}; + +template <> +struct LayoutMap> { + static LayoutTypeID const kId = LayoutTypeID::kTensorC64RSK64; +}; + +template <> +struct LayoutMap> { + static LayoutTypeID const kId = LayoutTypeID::kTensorK4RSC4; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct OpcodeClassMap; + +template <> +struct OpcodeClassMap { + static OpcodeClassID const kId = OpcodeClassID::kSimt; +}; + +template <> +struct OpcodeClassMap { + static OpcodeClassID const kId = 
OpcodeClassID::kTensorOp; +}; + +template <> +struct OpcodeClassMap { + static OpcodeClassID const kId = OpcodeClassID::kWmmaTensorOp; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct ArchTagMap; + +template <> +struct ArchTagMap { + static ArchTagID const kId = ArchTagID::kSm50; +}; + +template <> +struct ArchTagMap { + static ArchTagID const kId = ArchTagID::kSm60; +}; + +template <> +struct ArchTagMap { + static ArchTagID const kId = ArchTagID::kSm61; +}; + +template <> +struct ArchTagMap { + static ArchTagID const kId = ArchTagID::kSm70; +}; + +template <> +struct ArchTagMap { + static ArchTagID const kId = ArchTagID::kSm72; +}; + +template <> +struct ArchTagMap { + static ArchTagID const kId = ArchTagID::kSm75; +}; + +template <> +struct ArchTagMap { + static ArchTagID const kId = ArchTagID::kSm80; +}; + +template <> +struct ArchTagMap { + static ArchTagID const kId = ArchTagID::kSm86; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct ComplexTransformMap; + +template <> +struct ComplexTransformMap { + static cutlass::library::ComplexTransform const kId = + cutlass::library::ComplexTransform::kNone; +}; + +template <> +struct ComplexTransformMap { + static cutlass::library::ComplexTransform const kId = + cutlass::library::ComplexTransform::kConjugate; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct ConvModeMap; + +template <> +struct ConvModeMap { + static ConvModeID const kId = ConvModeID::kCrossCorrelation; +}; + +template <> +struct ConvModeMap { + static ConvModeID const kId = ConvModeID::kConvolution; +}; + +template +struct ConvKindMap; + +template <> +struct ConvKindMap { + static ConvKind const kId = ConvKind::kFprop; +}; + +template <> +struct ConvKindMap { + static ConvKind const kId = ConvKind::kDgrad; +}; + +template <> +struct ConvKindMap { + static ConvKind const kId = ConvKind::kWgrad; +}; + +template +struct IteratorAlgorithmMap; + +template <> +struct IteratorAlgorithmMap { + static IteratorAlgorithmID const kId = IteratorAlgorithmID::kAnalytic; +}; + +template <> +struct IteratorAlgorithmMap { + static IteratorAlgorithmID const kId = IteratorAlgorithmID::kOptimized; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct ThreadblockSwizzleMap; + +template +struct ThreadblockSwizzleMap< + gemm::threadblock::GemmIdentityThreadblockSwizzle> { + static ThreadblockSwizzleID const kId = ThreadblockSwizzleID::kGemmIdentity; +}; + +template <> +struct ThreadblockSwizzleMap< + gemm::threadblock::GemmHorizontalThreadblockSwizzle> { + static ThreadblockSwizzleID const kId = + ThreadblockSwizzleID::kGemmHorizontal; +}; + +template <> +struct ThreadblockSwizzleMap< + gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle> { + static ThreadblockSwizzleID const kId = + ThreadblockSwizzleID::kGemmBatchedIdentity; +}; + +template +struct ThreadblockSwizzleMap< + gemm::threadblock::GemmSplitKIdentityThreadblockSwizzle> { + static ThreadblockSwizzleID const kId = + ThreadblockSwizzleID::kGemmSplitKIdentity; +}; + +template <> +struct ThreadblockSwizzleMap< + gemm::threadblock::GemmSplitKHorizontalThreadblockSwizzle> { + static ThreadblockSwizzleID const kId = + ThreadblockSwizzleID::kGemmSplitKHorizontal; +}; + +template <> +struct ThreadblockSwizzleMap< + 
gemm::threadblock::GemvBatchedStridedThreadblockDefaultSwizzle> { + static ThreadblockSwizzleID const kId = + ThreadblockSwizzleID::kGemvBatchedStridedDefault; +}; + +template <> +struct ThreadblockSwizzleMap< + gemm::threadblock::GemvBatchedStridedThreadblockReductionSwizzle> { + static ThreadblockSwizzleID const kId = + ThreadblockSwizzleID::kGemvBatchedStridedReduction; +}; + +template <> +struct ThreadblockSwizzleMap< + conv::threadblock::ConvolutionFpropCxRSKxThreadblockSwizzle> { + static ThreadblockSwizzleID const kId = + ThreadblockSwizzleID::kConvolutionFpropCxRSKx; +}; + +template <> +struct ThreadblockSwizzleMap< + conv::threadblock::ConvolutionDgradCxRSKxThreadblockSwizzle> { + static ThreadblockSwizzleID const kId = + ThreadblockSwizzleID::kConvolutionDgradCxRSKx; +}; + +template <> +struct ThreadblockSwizzleMap< + conv::threadblock::ConvolutionFpropNCxHWxThreadblockSwizzle> { + static ThreadblockSwizzleID const kId = + ThreadblockSwizzleID::kConvolutionFpropNCxHWx; +}; + +template <> +struct ThreadblockSwizzleMap< + conv::threadblock::ConvolutionFpropTransThreadblockSwizzle> { + static ThreadblockSwizzleID const kId = + ThreadblockSwizzleID::kConvolutionFpropTrans; +}; + +template <> +struct ThreadblockSwizzleMap< + conv::threadblock::ConvolutionDgradNCxHWxThreadblockSwizzle> { + static ThreadblockSwizzleID const kId = + ThreadblockSwizzleID::kConvolutionDgradNCxHWx; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +TensorDescription make_TensorDescription(int alignment = 1) { + TensorDescription desc; + + desc.element = NumericTypeMap::kId; + desc.layout = LayoutMap::kId; + desc.alignment = alignment; + desc.log_extent_range = + int(sizeof(typename Layout::TensorCoord::Index) - 1) * 8; + desc.log_stride_range = int(sizeof(typename Layout::Stride::Index) - 1) * 8; + + return desc; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/dnn/src/cuda/cutlass/manifest.cpp b/dnn/src/cuda/cutlass/manifest.cpp new file mode 100644 index 00000000..2dd5ac85 --- /dev/null +++ b/dnn/src/cuda/cutlass/manifest.cpp @@ -0,0 +1,96 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + *modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + *notice, this list of conditions and the following disclaimer in the + *documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its + *contributors may be used to endorse or promote products derived from this + *software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + *DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT,
+ * INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+ * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+ * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
+ * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
+ * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/**
+ * \file dnn/src/cuda/cutlass/manifest.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+ * or implied.
+ */
+
+#include <memory>
+
+#include "src/cuda/cutlass/manifest.h"
+
+namespace cutlass {
+namespace library {
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Top-level initialization
+Status Manifest::initialize() {
+    if (!operations_.empty()) {
+        operations_.clear();
+    }
+
+    // initialize the procedurally generated cutlass ops in the manifest object
+    initialize_all(*this);
+
+    return Status::kSuccess;
+}
+
+/// Used for initialization
+void Manifest::reserve(size_t operation_count) {
+    operations_.reserve(operation_count);
+}
+
+/// Graceful shutdown
+Status Manifest::release() {
+    operations_.clear();
+    return Status::kSuccess;
+}
+
+/// Appends an operation and takes ownership
+void Manifest::append(Operation* operation_ptr) {
+    operations_.emplace_back(operation_ptr);
+}
+
+/// Returns the vector of owned operations
+OperationVector const& Manifest::operations() const {
+    return operations_;
+}
+
+/// Returns a const iterator to the first operation
+OperationVector::const_iterator Manifest::begin() const {
+    return operations_.begin();
+}
+
+/// Returns a const iterator past the last operation
+OperationVector::const_iterator Manifest::end() const {
+    return operations_.end();
+}
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
+
+}  // namespace library
+}  // namespace cutlass
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/dnn/src/cuda/cutlass/manifest.h b/dnn/src/cuda/cutlass/manifest.h
new file mode 100644
index 00000000..396074f3
--- /dev/null
+++ b/dnn/src/cuda/cutlass/manifest.h
@@ -0,0 +1,108 @@
+/***************************************************************************************************
+ * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *     * Redistributions of source code must retain the above copyright notice,
+ *       this list of conditions and the following disclaimer.
+ *     * Redistributions in binary form must reproduce the above copyright
+ *       notice, this list of conditions and the following disclaimer in the
+ *       documentation and/or other materials provided with the distribution.
+ *     * Neither the name of the NVIDIA CORPORATION nor the names of its
+ *       contributors may be used to endorse or promote products derived from
+ *       this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ * ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
+ * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+ * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+ * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
+ * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
+ * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/**
+ * \file dnn/src/cuda/cutlass/manifest.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+ * or implied.
+ */
+
+#pragma once
+
+#include <cstddef>
+#include <memory>
+#include <vector>
+
+#include "src/cuda/cutlass/library.h"
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass {
+namespace library {
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
+
+// Forward declaration
+class Manifest;
+
+// init and insert all cutlass gemm operations in the manifest object
+// (procedurally generated using generator.py)
+void initialize_all(Manifest& manifest);
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// List of operations
+using OperationVector = std::vector<std::unique_ptr<Operation>>;
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Manifest of CUTLASS Library
+class Manifest {
+private:
+    /// Operation provider
+    Provider provider_;
+
+    /// Global list of operations
+    OperationVector operations_;
+
+public:
+    Manifest(Provider provider = library::Provider::kCUTLASS)
+            : provider_(provider) {}
+
+    /// Top-level initialization
+    Status initialize();
+
+    /// Used for initialization
+    void reserve(size_t operation_count);
+
+    /// Graceful shutdown
+    Status release();
+
+    /// Appends an operation and takes ownership
+    void append(Operation* operation_ptr);
+
+    /// Returns the vector of owned operations
+    OperationVector const& operations() const;
+
+    /// Returns a const iterator to the first operation
+    OperationVector::const_iterator begin() const;
+
+    /// Returns a const iterator past the last operation
+    OperationVector::const_iterator end() const;
+};
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
+
+}  // namespace library
+}  // namespace cutlass
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
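+
+// Usage sketch: a Manifest is normally owned by the Singleton defined in
+// src/cuda/cutlass/singleton.h (added by this patch); used standalone it
+// looks like:
+//
+//     Manifest manifest;
+//     manifest.initialize();  // registers every generated kernel
+//     for (auto const& op : manifest.operations()) {
+//         printf("%s\n", op->description().name);
+//     }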
diff --git a/dnn/src/cuda/cutlass/operation_table.cpp b/dnn/src/cuda/cutlass/operation_table.cpp
new file mode 100644
index 00000000..a4858fe8
--- /dev/null
+++ b/dnn/src/cuda/cutlass/operation_table.cpp
@@ -0,0 +1,179 @@
+/***************************************************************************************************
+ * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *     * Redistributions of source code must retain the above copyright notice,
+ *       this list of conditions and the following disclaimer.
+ *     * Redistributions in binary form must reproduce the above copyright
+ *       notice, this list of conditions and the following disclaimer in the
+ *       documentation and/or other materials provided with the distribution.
+ *     * Neither the name of the NVIDIA CORPORATION nor the names of its
+ *       contributors may be used to endorse or promote products derived from
+ *       this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ * ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
+ * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+ * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+ * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
+ * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
+ * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/**
+ * \file dnn/src/cuda/cutlass/operation_table.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */ + +#include "src/common/utils.h" +#include "src/cuda/cutlass/operation_table.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +GemmKey get_gemm_key_from_desc(const GemmDescription& desc) { + GemmKey key; + + key.element_A = desc.A.element; + key.layout_A = desc.A.layout; + key.element_B = desc.B.element; + key.layout_B = desc.B.layout; + key.element_C = desc.C.element; + key.layout_C = desc.C.layout; + + key.threadblock_shape_m = desc.tile_description.threadblock_shape.m(); + key.threadblock_shape_n = desc.tile_description.threadblock_shape.n(); + key.threadblock_shape_k = desc.tile_description.threadblock_shape.k(); + + key.warp_shape_m = desc.tile_description.threadblock_shape.m() / + desc.tile_description.warp_count.m(); + key.warp_shape_n = desc.tile_description.threadblock_shape.n() / + desc.tile_description.warp_count.n(); + key.warp_shape_k = desc.tile_description.threadblock_shape.k() / + desc.tile_description.warp_count.k(); + + key.instruction_shape_m = + desc.tile_description.math_instruction.instruction_shape.m(); + key.instruction_shape_n = + desc.tile_description.math_instruction.instruction_shape.n(); + key.instruction_shape_k = + desc.tile_description.math_instruction.instruction_shape.k(); + + key.stages = desc.stages; + key.split_k_mode = desc.split_k_mode; + + return key; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +ConvolutionKey get_convolution_key_from_desc( + const ConvolutionDescription& desc) { + ConvolutionKey key; + + key.conv_op = desc.conv_op; + + key.element_src = desc.src.element; + key.layout_src = desc.src.layout; + key.element_filter = desc.filter.element; + key.layout_filter = desc.filter.layout; + key.element_dst = desc.dst.element; + key.layout_dst = desc.dst.layout; + key.element_bias = desc.bias.element; + key.layout_bias = desc.bias.layout; + + key.convolution_type = desc.convolution_type; + + key.threadblock_shape_m = desc.tile_description.threadblock_shape.m(); + key.threadblock_shape_n = desc.tile_description.threadblock_shape.n(); + key.threadblock_shape_k = desc.tile_description.threadblock_shape.k(); + + key.warp_shape_m = desc.tile_description.threadblock_shape.m() / + desc.tile_description.warp_count.m(); + key.warp_shape_n = desc.tile_description.threadblock_shape.n() / + desc.tile_description.warp_count.n(); + key.warp_shape_k = desc.tile_description.threadblock_shape.k() / + desc.tile_description.warp_count.k(); + + key.instruction_shape_m = + desc.tile_description.math_instruction.instruction_shape.m(); + key.instruction_shape_n = + desc.tile_description.math_instruction.instruction_shape.n(); + key.instruction_shape_k = + desc.tile_description.math_instruction.instruction_shape.k(); + + key.epilogue_type = desc.epilogue_type; + + key.stages = desc.tile_description.threadblock_stages; + key.need_load_from_const_mem = desc.need_load_from_const_mem; + key.without_shared_load = desc.without_shared_load; + + return key; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +void OperationTable::append(Manifest const& manifest) { + // Insert operations into appropriate data structure + for (auto const& operation : manifest) { + OperationDescription const& desc = operation->description(); + + // insert all gemm operations into 
operation table
+        if (desc.kind == OperationKind::kGemm) {
+            GemmKey key = get_gemm_key_from_desc(
+                    static_cast<GemmDescription const&>(desc));
+            gemm_operations[key].push_back(operation.get());
+        }
+
+        // insert all conv operations into the operation table
+        if (desc.kind == OperationKind::kConvolution) {
+            ConvolutionKey key = get_convolution_key_from_desc(
+                    static_cast<ConvolutionDescription const&>(desc));
+            convolution_operations[key].push_back(operation.get());
+        }
+    }
+}
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+Operation const* OperationTable::find_op(GemmKey const& key) const {
+    megdnn_assert(gemm_operations.count(key) > 0,
+                  "key not found in cutlass operation table");
+    auto const& ops = gemm_operations.at(key);
+    megdnn_assert(ops.size() == 1, "exactly one kernel expected, got %zu",
+                  ops.size());
+    return ops[0];
+}
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+Operation const* OperationTable::find_op(ConvolutionKey const& key) const {
+    megdnn_assert(convolution_operations.count(key) > 0,
+                  "key not found in cutlass operation table");
+    auto const& ops = convolution_operations.at(key);
+    megdnn_assert(ops.size() == 1, "exactly one kernel expected, got %zu",
+                  ops.size());
+    return ops[0];
+}
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+}  // namespace library
+}  // namespace cutlass
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
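+
+// Usage sketch: algorithms build a key that mirrors the kernel's template
+// parameters and fetch the matching operation from the global table; the
+// Singleton accessor follows src/cuda/cutlass/singleton.h in this patch:
+//
+//     ConvolutionKey key;
+//     /* fill every field of `key`: key.conv_op, key.element_src, ... */
+//     Operation const* op =
+//             Singleton::get().operation_table.find_op(key);
+//     op->run(&conv_args, workspace_ptr, stream);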
diff --git a/dnn/src/cuda/cutlass/operation_table.h b/dnn/src/cuda/cutlass/operation_table.h
new file mode 100644
index 00000000..6190ff56
--- /dev/null
+++ b/dnn/src/cuda/cutlass/operation_table.h
@@ -0,0 +1,334 @@
+/***************************************************************************************************
+ * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ *modification, are permitted provided that the following conditions are met:
+ *     * Redistributions of source code must retain the above copyright notice,
+ *this list of conditions and the following disclaimer.
+ *     * Redistributions in binary form must reproduce the above copyright
+ *notice, this list of conditions and the following disclaimer in the
+ *documentation and/or other materials provided with the distribution.
+ *     * Neither the name of the NVIDIA CORPORATION nor the names of its
+ *contributors may be used to endorse or promote products derived from this
+ *software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ *DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT,
+ *INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+ * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ *DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+ *OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+ *NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
+ *EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/**
+ * \file dnn/src/cuda/cutlass/operation_table.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+
+#pragma once
+
+#include <unordered_map>
+
+#include "src/common/hash_ct.h"
+#include "src/cuda/cutlass/manifest.h"
+#include "src/cuda/cutlass/util.h"
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass {
+namespace library {
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+class Hash {
+public:
+    Hash() : m_val(0) {}
+
+    Hash& update(const void* ptr, size_t len) {
+        m_val += megdnn::XXHash64CT::hash((const char*)ptr, len, 123456);
+        return *this;
+    }
+
+    uint64_t digest() const { return m_val; }
+
+private:
+    uint64_t m_val;
+};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+// Data Structures for GemmOperationMap
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+struct GemmKey {
+    NumericTypeID element_A;
+    LayoutTypeID layout_A;
+    NumericTypeID element_B;
+    LayoutTypeID layout_B;
+    NumericTypeID element_C;
+    LayoutTypeID layout_C;
+
+    int threadblock_shape_m;
+    int threadblock_shape_n;
+    int threadblock_shape_k;
+
+    int warp_shape_m;
+    int warp_shape_n;
+    int warp_shape_k;
+
+    int instruction_shape_m;
+    int instruction_shape_n;
+    int instruction_shape_k;
+
+    int stages;
+    SplitKMode split_k_mode;
+
+    inline bool operator==(GemmKey const& rhs) const {
+        return (element_A == rhs.element_A) && (layout_A == rhs.layout_A) &&
+               (element_B == rhs.element_B) && (layout_B == rhs.layout_B) &&
+               (element_C == rhs.element_C) && (layout_C == rhs.layout_C) &&
+               (threadblock_shape_m == rhs.threadblock_shape_m) &&
+               (threadblock_shape_n == rhs.threadblock_shape_n) &&
+               (threadblock_shape_k == rhs.threadblock_shape_k) &&
+               (warp_shape_m == rhs.warp_shape_m) &&
+               (warp_shape_n == rhs.warp_shape_n) &&
+               (warp_shape_k == rhs.warp_shape_k) &&
+               (instruction_shape_m == rhs.instruction_shape_m) &&
+               (instruction_shape_n == rhs.instruction_shape_n) &&
+               (instruction_shape_k == rhs.instruction_shape_k) &&
+               (stages == rhs.stages) && (split_k_mode == rhs.split_k_mode);
+    }
+
+    inline bool operator!=(GemmKey const& rhs) const { return !(*this == rhs); }
+
+    inline std::string str() const {
+        auto tuple_to_str = [](int m, int n, int k) -> std::string {
+            return std::to_string(m) + " x " + std::to_string(n) + " x " +
+                   std::to_string(k);
+        };
+
+        std::string threadblock_shape_str = tuple_to_str(
+                threadblock_shape_m, threadblock_shape_n, threadblock_shape_k);
+        std::string warp_shape_str =
+                tuple_to_str(warp_shape_m, warp_shape_n, warp_shape_k);
+        std::string instruction_shape_str = tuple_to_str(
+                instruction_shape_m, instruction_shape_n, instruction_shape_k);
+
+        return std::string("{") + "\n element_A: " + to_string(element_A) +
+               "\n layout_A: " + to_string(layout_A) +
+               "\n element_B: " + to_string(element_B) +
+               "\n layout_B: " + to_string(layout_B) +
+               "\n element_C: " + to_string(element_C) +
+               "\n layout_C: " + to_string(layout_C) +
+               "\n threadblock_shape: " + threadblock_shape_str +
+               "\n warp_shape: " + warp_shape_str +
+               "\n instruction_shape: " + instruction_shape_str +
+               "\n stages: " + std::to_string(stages) +
+               "\n split_k_mode: " + to_string(split_k_mode) + "\n}";
+    }
+};
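Note that the warp_shape_* fields hold the per-warp tile, which get_gemm_key_from_desc() derives by dividing the threadblock shape by the warp count; e.g. a 128x128x32 threadblock split across a 2x2x1 warp grid (numbers are only an example) gives:

    // warp_shape_m = 128 / 2 = 64
    // warp_shape_n = 128 / 2 = 64
    // warp_shape_k =  32 / 1 = 32

Code that builds a GemmKey by hand must therefore supply the 64x64x32 warp tile, not the 2x2x1 warp count.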
+struct GemmKeyHasher {
+    inline size_t operator()(GemmKey const& key) const {
+        return Hash()
+                .update(&key.element_A, sizeof(key.element_A))
+                .update(&key.layout_A, sizeof(key.layout_A))
+                .update(&key.element_B, sizeof(key.element_B))
+                .update(&key.layout_B, sizeof(key.layout_B))
+                .update(&key.element_C, sizeof(key.element_C))
+                .update(&key.layout_C, sizeof(key.layout_C))
+                .update(&key.threadblock_shape_m,
+                        sizeof(key.threadblock_shape_m))
+                .update(&key.threadblock_shape_n,
+                        sizeof(key.threadblock_shape_n))
+                .update(&key.threadblock_shape_k,
+                        sizeof(key.threadblock_shape_k))
+                .update(&key.warp_shape_m, sizeof(key.warp_shape_m))
+                .update(&key.warp_shape_n, sizeof(key.warp_shape_n))
+                .update(&key.warp_shape_k, sizeof(key.warp_shape_k))
+                .update(&key.stages, sizeof(key.stages))
+                .update(&key.split_k_mode, sizeof(key.split_k_mode))
+                .digest();
+    }
+};
+
+using GemmOperationMap =
+        std::unordered_map<GemmKey, std::vector<Operation const*>,
+                           GemmKeyHasher>;
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+// Data Structures for ConvolutionOperationMap
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+struct ConvolutionKey {
+    conv::Operator conv_op;
+
+    library::NumericTypeID element_src;
+    library::LayoutTypeID layout_src;
+    library::NumericTypeID element_filter;
+    library::LayoutTypeID layout_filter;
+    library::NumericTypeID element_dst;
+    library::LayoutTypeID layout_dst;
+    library::NumericTypeID element_bias;
+    library::LayoutTypeID layout_bias;
+
+    conv::ConvType convolution_type;
+
+    int threadblock_shape_m;
+    int threadblock_shape_n;
+    int threadblock_shape_k;
+
+    int warp_shape_m;
+    int warp_shape_n;
+    int warp_shape_k;
+
+    int instruction_shape_m;
+    int instruction_shape_n;
+    int instruction_shape_k;
+
+    epilogue::EpilogueType epilogue_type;
+    int stages;
+    bool need_load_from_const_mem;
+    bool without_shared_load;
+
+    inline bool operator==(ConvolutionKey const& rhs) const {
+        return (conv_op == rhs.conv_op) && (element_src == rhs.element_src) &&
+               (layout_src == rhs.layout_src) &&
+               (element_filter == rhs.element_filter) &&
+               (layout_filter == rhs.layout_filter) &&
+               (element_dst == rhs.element_dst) &&
+               (layout_dst == rhs.layout_dst) &&
+               (element_bias == rhs.element_bias) &&
+               (layout_bias == rhs.layout_bias) &&
+               (convolution_type == rhs.convolution_type) &&
+               (threadblock_shape_m == rhs.threadblock_shape_m) &&
+               (threadblock_shape_n == rhs.threadblock_shape_n) &&
+               (threadblock_shape_k == rhs.threadblock_shape_k) &&
+               (warp_shape_m == rhs.warp_shape_m) &&
+               (warp_shape_n == rhs.warp_shape_n) &&
+               (warp_shape_k == rhs.warp_shape_k) &&
+               (instruction_shape_m == rhs.instruction_shape_m) &&
+               (instruction_shape_n == rhs.instruction_shape_n) &&
+               (instruction_shape_k == rhs.instruction_shape_k) &&
+               (epilogue_type == rhs.epilogue_type) && (stages == rhs.stages) &&
+               (need_load_from_const_mem == rhs.need_load_from_const_mem) &&
+               (without_shared_load == rhs.without_shared_load);
+    }
+
+    inline bool operator!=(ConvolutionKey const& rhs) const {
+        return !(*this == rhs);
+    }
+
+    inline std::string str() const {
+        auto tuple_to_str = [](int m, int n, int k) -> std::string {
+            return std::to_string(m) + " x " + std::to_string(n) + " x " +
+                   std::to_string(k);
+        };
+
+        std::string threadblock_shape_str = tuple_to_str(
+                threadblock_shape_m, threadblock_shape_n, threadblock_shape_k);
+        std::string warp_shape_str =
+                tuple_to_str(warp_shape_m, warp_shape_n, warp_shape_k);
+        std::string instruction_shape_str = tuple_to_str(
+                instruction_shape_m, instruction_shape_n, instruction_shape_k);
+
+        return std::string("{") + "\n conv_op: " + to_string(conv_op) +
+               "\n element_src: " + to_string(element_src) +
+               "\n layout_src: " + to_string(layout_src) +
+               "\n element_filter: " + to_string(element_filter) +
+               "\n layout_filter: " + to_string(layout_filter) +
+               "\n element_dst: " + to_string(element_dst) +
+               "\n layout_dst: " + to_string(layout_dst) +
+               "\n element_bias: " + to_string(element_bias) +
+               "\n layout_bias: " + to_string(layout_bias) +
+               "\n convolution_type: " + to_string(convolution_type) +
+               "\n threadblock_shape: " + threadblock_shape_str +
+               "\n warp_shape: " + warp_shape_str +
+               "\n instruction_shape: " + instruction_shape_str +
+               "\n epilogue_type: " + to_string(epilogue_type) +
+               "\n stages: " + std::to_string(stages) +
+               "\n need_load_from_const_mem: " +
+               to_string(need_load_from_const_mem) +
+               "\n without_shared_load: " + to_string(without_shared_load) +
+               "\n}";
+    }
+};
+struct ConvolutionKeyHasher {
+    inline size_t operator()(ConvolutionKey const& key) const {
+        return Hash()
+                .update(&key.conv_op, sizeof(key.conv_op))
+                .update(&key.element_src, sizeof(key.element_src))
+                .update(&key.layout_src, sizeof(key.layout_src))
+                .update(&key.element_filter, sizeof(key.element_filter))
+                .update(&key.layout_filter, sizeof(key.layout_filter))
+                .update(&key.element_dst, sizeof(key.element_dst))
+                .update(&key.layout_dst, sizeof(key.layout_dst))
+                .update(&key.element_bias, sizeof(key.element_bias))
+                .update(&key.layout_bias, sizeof(key.layout_bias))
+                .update(&key.convolution_type, sizeof(key.convolution_type))
+                .update(&key.threadblock_shape_m,
+                        sizeof(key.threadblock_shape_m))
+                .update(&key.threadblock_shape_n,
+                        sizeof(key.threadblock_shape_n))
+                .update(&key.threadblock_shape_k,
+                        sizeof(key.threadblock_shape_k))
+                .update(&key.warp_shape_m, sizeof(key.warp_shape_m))
+                .update(&key.warp_shape_n, sizeof(key.warp_shape_n))
+                .update(&key.warp_shape_k, sizeof(key.warp_shape_k))
+                .update(&key.instruction_shape_m,
+                        sizeof(key.instruction_shape_m))
+                .update(&key.instruction_shape_n,
+                        sizeof(key.instruction_shape_n))
+                .update(&key.instruction_shape_k,
+                        sizeof(key.instruction_shape_k))
+                .update(&key.epilogue_type, sizeof(key.epilogue_type))
+                .update(&key.stages, sizeof(key.stages))
+                .update(&key.need_load_from_const_mem,
+                        sizeof(key.need_load_from_const_mem))
+                .update(&key.without_shared_load,
+                        sizeof(key.without_shared_load))
+                .digest();
+    }
+};
+
+using ConvolutionOperationMap =
+        std::unordered_map<ConvolutionKey, std::vector<Operation const*>,
+                           ConvolutionKeyHasher>;
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Table of cutlass::library::Operation instances
+class OperationTable {
+public:
+    /// Map of all operations of type kGemm
+    GemmOperationMap gemm_operations;
+
+    /// Map of all operations of type kConvolution
+    ConvolutionOperationMap convolution_operations;
+
+public:
+    void append(Manifest const& manifest);
+
+    Operation const* find_op(GemmKey const& key) const;
+
+    Operation const* find_op(ConvolutionKey const& key) const;
+};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+}  // namespace library
+}  // namespace cutlass
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
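The intended population flow, mirroring what the Singleton in the next file does at startup (sketch only):

    using namespace cutlass::library;

    Manifest manifest;
    manifest.initialize();   // registers every generated kernel
    OperationTable table;
    table.append(manifest);  // buckets them into gemm_operations /
                             // convolution_operations by key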
diff --git a/dnn/src/cuda/cutlass/singleton.cu b/dnn/src/cuda/cutlass/singleton.cu
new file mode 100644
index 00000000..614a568f
--- /dev/null
+++ b/dnn/src/cuda/cutlass/singleton.cu
@@ -0,0 +1,72 @@
+/***************************************************************************************************
+ * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ *modification, are permitted provided that the following conditions are met:
+ *     * Redistributions of source code must retain the above copyright notice,
+ *this list of conditions and the following disclaimer.
+ *     * Redistributions in binary form must reproduce the above copyright
+ *notice, this list of conditions and the following disclaimer in the
+ *documentation and/or other materials provided with the distribution.
+ *     * Neither the name of the NVIDIA CORPORATION nor the names of its
+ *contributors may be used to endorse or promote products derived from this
+ *software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ *DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT,
+ *INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+ * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ *DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+ *OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+ *NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
+ *EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/**
+ * \file dnn/src/cuda/cutlass/singleton.cu
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+
+#include <memory>
+
+#include "src/cuda/cutlass/singleton.h"
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass {
+namespace library {
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+static std::unique_ptr<Singleton> instance;
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+Singleton::Singleton() {
+    manifest.initialize();
+
+    operation_table.append(manifest);
+}
+
+Singleton const& Singleton::get() {
+    if (!instance.get()) {
+        instance.reset(new Singleton);
+    }
+    return *instance.get();
+}
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+}  // namespace library
+}  // namespace cutlass
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/dnn/src/cuda/cutlass/singleton.h b/dnn/src/cuda/cutlass/singleton.h
new file mode 100644
index 00000000..c67ad528
--- /dev/null
+++ b/dnn/src/cuda/cutlass/singleton.h
@@ -0,0 +1,70 @@
+/***************************************************************************************************
+ * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ *modification, are permitted provided that the following conditions are met:
+ *     * Redistributions of source code must retain the above copyright notice,
+ *this list of conditions and the following disclaimer.
+ *     * Redistributions in binary form must reproduce the above copyright
+ *notice, this list of conditions and the following disclaimer in the
+ *documentation and/or other materials provided with the distribution.
+ *     * Neither the name of the NVIDIA CORPORATION nor the names of its
+ *contributors may be used to endorse or promote products derived from this
+ *software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ *DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT,
+ *INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+ * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ *DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+ *OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+ *NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
+ *EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/**
+ * \file dnn/src/cuda/cutlass/singleton.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+
+#pragma once
+
+#include "src/cuda/cutlass/operation_table.h"
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass {
+namespace library {
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Singleton instance stores a Manifest and Operation table
+class Singleton {
+public:
+    /// Manifest object
+    Manifest manifest;
+
+    /// Operation table referencing the Manifest
+    OperationTable operation_table;
+
+public:
+    Singleton();
+
+    static Singleton const& get();
+};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+}  // namespace library
+}  // namespace cutlass
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
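Call sites are expected to reach the table through the singleton, e.g.:

    auto const& table = cutlass::library::Singleton::get().operation_table;

Note that get() builds the instance lazily and without a lock, so the first call is assumed to happen before any concurrent use.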
diff --git a/dnn/src/cuda/cutlass/util.cu b/dnn/src/cuda/cutlass/util.cu
new file mode 100644
index 00000000..11e1d301
--- /dev/null
+++ b/dnn/src/cuda/cutlass/util.cu
@@ -0,0 +1,1600 @@
+/***************************************************************************************************
+ * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ *modification, are permitted provided that the following conditions are met:
+ *     * Redistributions of source code must retain the above copyright notice,
+ *this list of conditions and the following disclaimer.
+ *     * Redistributions in binary form must reproduce the above copyright
+ *notice, this list of conditions and the following disclaimer in the
+ *documentation and/or other materials provided with the distribution.
+ *     * Neither the name of the NVIDIA CORPORATION nor the names of its
+ *contributors may be used to endorse or promote products derived from this
+ *software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ *DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT,
+ *INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+ * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ *DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+ *OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+ *NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
+ *EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/**
+ * \file dnn/src/cuda/cutlass/util.cu
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+
+#if __CUDACC_VER_MAJOR__ > 9 || \
+        (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
+
+#include <complex>
+#include <sstream>
+
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wreorder"
+#pragma GCC diagnostic ignored "-Wstrict-aliasing"
+#pragma GCC diagnostic ignored "-Wunused-parameter"
+
+#include "cutlass/complex.h"
+#include "cutlass/cutlass.h"
+#include "cutlass/layout/matrix.h"
+#include "cutlass/numeric_types.h"
+
+#include "src/cuda/cutlass/util.h"
+
+#pragma GCC diagnostic pop
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass {
+namespace library {
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+static struct {
+    char const* text;
+    char const* pretty;
+    Provider enumerant;
+} Provider_enumerants[] = {
+        {"none", "None", Provider::kNone},
+        {"cutlass", "CUTLASS", Provider::kCUTLASS},
+        {"host", "reference_host", Provider::kReferenceHost},
+        {"device", "reference_device", Provider::kReferenceDevice},
+        {"cublas", "cuBLAS", Provider::kCUBLAS},
+        {"cudnn", "cuDNN", Provider::kCUDNN},
+};
+
+/// Converts a Provider enumerant to a string
+char const* to_string(Provider provider, bool pretty) {
+    for (auto const& possible : Provider_enumerants) {
+        if (provider == possible.enumerant) {
+            if (pretty) {
+                return possible.pretty;
+            } else {
+                return possible.text;
+            }
+        }
+    }
+
+    return pretty ? "Invalid" : "invalid";
+}
+
+/// Parses a Provider enumerant from a string
+template <>
+Provider from_string<Provider>(std::string const& str) {
+    for (auto const& possible : Provider_enumerants) {
+        if ((str.compare(possible.text) == 0) ||
+            (str.compare(possible.pretty) == 0)) {
+            return possible.enumerant;
+        }
+    }
+
+    return Provider::kInvalid;
+}
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
+
+static struct {
+    char const* text;
+    char const* pretty;
+    GemmKind enumerant;
+} GemmKind_enumerants[] = {
+        {"gemm", "", GemmKind::kGemm},
+        {"spgemm", "", GemmKind::kSparse},
+        {"universal", "", GemmKind::kUniversal},
+        {"planar_complex", "", GemmKind::kPlanarComplex},
+        {"planar_complex_array", "", GemmKind::kPlanarComplexArray},
+};
+
+/// Converts a GemmKind enumerant to a string
+char const* to_string(GemmKind type, bool pretty) {
+    for (auto const& possible : GemmKind_enumerants) {
+        if (type == possible.enumerant) {
+            if (pretty) {
+                return possible.pretty;
+            } else {
+                return possible.text;
+            }
+        }
+    }
+
+    return pretty ? "Invalid" : "invalid";
+}
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
+
+static struct {
+    char const* text;
+    char const* pretty;
+    OperationKind enumerant;
+} OperationKind_enumerants[] = {
+        {"eq_gemm", "EqGemm", OperationKind::kEqGemm},
+        {"gemm", "Gemm", OperationKind::kGemm},
+        {"conv2d", "Conv2d", OperationKind::kConv2d},
+        {"conv3d", "Conv3d", OperationKind::kConv3d},
+        {"spgemm", "SparseGemm", OperationKind::kSparseGemm},
+};
+
+/// Converts an OperationKind enumerant to a string
+char const* to_string(OperationKind enumerant, bool pretty) {
+    for (auto const& possible : OperationKind_enumerants) {
+        if (enumerant == possible.enumerant) {
+            if (pretty) {
+                return possible.pretty;
+            } else {
+                return possible.text;
+            }
+        }
+    }
+
+    return pretty ? "Invalid" : "invalid";
+}
"Invalid" : "invalid"; +} + +/// Converts a Status enumerant from a string +template <> +OperationKind from_string(std::string const& str) { + for (auto const& possible : OperationKind_enumerants) { + if ((str.compare(possible.text) == 0) || + (str.compare(possible.pretty) == 0)) { + return possible.enumerant; + } + } + + return OperationKind::kInvalid; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +static struct { + char const* text; + char const* pretty; + Status enumerant; +} Status_enumerants[] = { + {"success", "Success", Status::kSuccess}, + {"misaligned_operand", "Error: misaligned operand", + Status::kErrorMisalignedOperand}, + {"invalid_problem", "Error: invalid problem", + Status::kErrorInvalidProblem}, + {"not_supported", "Error: not supported", Status::kErrorNotSupported}, + {"internal", "Error: internal", Status::kErrorInternal}}; + +/// Converts a Status enumerant to a string +char const* to_string(Status status, bool pretty) { + for (auto const& possible : Status_enumerants) { + if (status == possible.enumerant) { + if (pretty) { + return possible.pretty; + } else { + return possible.text; + } + } + } + + return pretty ? "Invalid" : "invalid"; +} + +/// Converts a Status enumerant from a string +template <> +Status from_string(std::string const& str) { + for (auto const& possible : Status_enumerants) { + if ((str.compare(possible.text) == 0) || + (str.compare(possible.pretty) == 0)) { + return possible.enumerant; + } + } + + return Status::kInvalid; +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +static struct { + char const* text; + char const* pretty; + NumericTypeID enumerant; +} NumericTypeID_enumerants[] = { + {"unknown", "", NumericTypeID::kUnknown}, + {"void", "Void", NumericTypeID::kVoid}, + {"b1", "B1", NumericTypeID::kB1}, + {"u2", "U2", NumericTypeID::kU2}, + {"u4", "U4", NumericTypeID::kU4}, + {"u8", "U8", NumericTypeID::kU8}, + {"u16", "U16", NumericTypeID::kU16}, + {"u32", "U32", NumericTypeID::kU32}, + {"u64", "U64", NumericTypeID::kU64}, + {"s2", "S2", NumericTypeID::kS2}, + {"s4", "S4", NumericTypeID::kS4}, + {"s8", "S8", NumericTypeID::kS8}, + {"s16", "S16", NumericTypeID::kS16}, + {"s32", "S32", NumericTypeID::kS32}, + {"s64", "S64", NumericTypeID::kS64}, + {"f16", "F16", NumericTypeID::kF16}, + {"bf16", "BF16", NumericTypeID::kBF16}, + {"f32", "F32", NumericTypeID::kF32}, + {"tf32", "TF32", NumericTypeID::kTF32}, + {"f64", "F64", NumericTypeID::kF64}, + {"cf16", "CF16", NumericTypeID::kCF16}, + {"cbf16", "CBF16", NumericTypeID::kCBF16}, + {"cf32", "CF32", NumericTypeID::kCF32}, + {"ctf32", "CTF32", NumericTypeID::kCTF32}, + {"cf64", "CF64", NumericTypeID::kCF64}, + {"cu2", "CU2", NumericTypeID::kCU2}, + {"cu4", "CU4", NumericTypeID::kCU4}, + {"cu8", "CU8", NumericTypeID::kCU8}, + {"cu16", "CU16", NumericTypeID::kCU16}, + {"cu32", "CU32", NumericTypeID::kCU32}, + {"cu64", "CU64", NumericTypeID::kCU64}, + {"cs2", "CS2", NumericTypeID::kCS2}, + {"cs4", "CS4", NumericTypeID::kCS4}, + {"cs8", "CS8", NumericTypeID::kCS8}, + {"cs16", "CS16", NumericTypeID::kCS16}, + {"cs32", "CS32", NumericTypeID::kCS32}, + {"cs64", "CS64", NumericTypeID::kCS64}, + {"*", "", NumericTypeID::kUnknown}}; + +/// Converts a NumericTypeID enumerant to a string +char const* to_string(NumericTypeID type, bool pretty) { + for (auto const& possible : NumericTypeID_enumerants) { + if (type == possible.enumerant) { + if (pretty) { + return possible.pretty; + } else { + 
+/// Parses a NumericTypeID enumerant from a string
+template <>
+NumericTypeID from_string<NumericTypeID>(std::string const& str) {
+    for (auto const& possible : NumericTypeID_enumerants) {
+        if ((str.compare(possible.text) == 0) ||
+            (str.compare(possible.pretty) == 0)) {
+            return possible.enumerant;
+        }
+    }
+
+    return NumericTypeID::kInvalid;
+}
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Returns the size of a data type in bits
+int sizeof_bits(NumericTypeID type) {
+    switch (type) {
+        case NumericTypeID::kF16:
+            return 16;
+        case NumericTypeID::kBF16:
+            return 16;
+        case NumericTypeID::kTF32:
+            return 32;
+        case NumericTypeID::kF32:
+            return 32;
+        case NumericTypeID::kF64:
+            return 64;
+        case NumericTypeID::kCF16:
+            return 32;
+        case NumericTypeID::kCBF16:
+            return 32;
+        case NumericTypeID::kCF32:
+            return 64;
+        case NumericTypeID::kCTF32:
+            return 64;
+        case NumericTypeID::kCF64:
+            return 128;
+        case NumericTypeID::kS2:
+            return 2;
+        case NumericTypeID::kS4:
+            return 4;
+        case NumericTypeID::kS8:
+            return 8;
+        case NumericTypeID::kS16:
+            return 16;
+        case NumericTypeID::kS32:
+            return 32;
+        case NumericTypeID::kS64:
+            return 64;
+        case NumericTypeID::kU2:
+            return 2;
+        case NumericTypeID::kU4:
+            return 4;
+        case NumericTypeID::kU8:
+            return 8;
+        case NumericTypeID::kU16:
+            return 16;
+        case NumericTypeID::kU32:
+            return 32;
+        case NumericTypeID::kU64:
+            return 64;
+        case NumericTypeID::kB1:
+            return 1;
+        default:
+            break;
+    }
+    return 0;
+}
+
+/// Returns true if the numeric type is a complex data type or false if
+/// real-valued.
+bool is_complex_type(NumericTypeID type) {
+    switch (type) {
+        case NumericTypeID::kCF16:
+            return true;
+        case NumericTypeID::kCF32:
+            return true;
+        case NumericTypeID::kCF64:
+            return true;
+        case NumericTypeID::kCBF16:
+            return true;
+        case NumericTypeID::kCTF32:
+            return true;
+        default:
+            break;
+    }
+    return false;
+}
+
+/// Returns the field underlying a complex valued type
+NumericTypeID get_real_type(NumericTypeID type) {
+    switch (type) {
+        case NumericTypeID::kCF16:
+            return NumericTypeID::kF16;
+        case NumericTypeID::kCF32:
+            return NumericTypeID::kF32;
+        case NumericTypeID::kCF64:
+            return NumericTypeID::kF64;
+        case NumericTypeID::kCBF16:
+            return NumericTypeID::kBF16;
+        case NumericTypeID::kCTF32:
+            return NumericTypeID::kTF32;
+        default:
+            break;
+    }
+    return type;
+}
+
+/// Returns true if numeric type is integer
+bool is_integer_type(NumericTypeID type) {
+    switch (type) {
+        case NumericTypeID::kS2:
+            return true;
+        case NumericTypeID::kS4:
+            return true;
+        case NumericTypeID::kS8:
+            return true;
+        case NumericTypeID::kS16:
+            return true;
+        case NumericTypeID::kS32:
+            return true;
+        case NumericTypeID::kS64:
+            return true;
+        case NumericTypeID::kU2:
+            return true;
+        case NumericTypeID::kU4:
+            return true;
+        case NumericTypeID::kU8:
+            return true;
+        case NumericTypeID::kU16:
+            return true;
+        case NumericTypeID::kU32:
+            return true;
+        case NumericTypeID::kU64:
+            return true;
+        default:
+            break;
+    }
+    return false;
+}
+
+/// Returns true if numeric type is signed
+bool is_signed_type(NumericTypeID type) {
+    switch (type) {
+        case NumericTypeID::kF16:
+            return true;
+        case NumericTypeID::kBF16:
+            return true;
+        case NumericTypeID::kTF32:
+            return true;
+        case NumericTypeID::kF32:
+            return true;
+        case NumericTypeID::kF64:
+            return true;
+        case NumericTypeID::kS2:
+            return true;
+        case NumericTypeID::kS4:
+            return 
true; + case NumericTypeID::kS8: + return true; + case NumericTypeID::kS16: + return true; + case NumericTypeID::kS32: + return true; + case NumericTypeID::kS64: + return true; + default: + break; + } + return false; +} + +/// Returns true if numeric type is a signed integer +bool is_signed_integer(NumericTypeID type) { + return is_integer_type(type) && is_signed_type(type); +} + +/// returns true if numeric type is an unsigned integer +bool is_unsigned_integer(NumericTypeID type) { + return is_integer_type(type) && !is_signed_type(type); +} + +/// Returns true if numeric type is floating-point type +bool is_float_type(NumericTypeID type) { + switch (type) { + case NumericTypeID::kF16: + return true; + case NumericTypeID::kBF16: + return true; + case NumericTypeID::kTF32: + return true; + case NumericTypeID::kF32: + return true; + case NumericTypeID::kF64: + return true; + case NumericTypeID::kCF16: + return true; + case NumericTypeID::kCBF16: + return true; + case NumericTypeID::kCTF32: + return true; + case NumericTypeID::kCF32: + return true; + case NumericTypeID::kCF64: + return true; + default: + break; + } + return false; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +static struct { + LayoutTypeID layout; + char const* alias; +} layout_aliases[] = {{LayoutTypeID::kUnknown, "unknown"}, + {LayoutTypeID::kRowMajor, "row"}, + {LayoutTypeID::kRowMajor, "t"}, + {LayoutTypeID::kColumnMajor, "column"}, + {LayoutTypeID::kColumnMajor, "col"}, + {LayoutTypeID::kColumnMajor, "n"}, + + {LayoutTypeID::kColumnMajorInterleavedK2, "nk2"}, + {LayoutTypeID::kRowMajorInterleavedK2, "tk2"}, + + {LayoutTypeID::kColumnMajorInterleavedK4, "nk4"}, + {LayoutTypeID::kRowMajorInterleavedK4, "tk4"}, + + {LayoutTypeID::kColumnMajorInterleavedK16, "nk16"}, + {LayoutTypeID::kRowMajorInterleavedK16, "tk16"}, + + {LayoutTypeID::kColumnMajorInterleavedK32, "nk32"}, + {LayoutTypeID::kRowMajorInterleavedK32, "tk32"}, + + {LayoutTypeID::kColumnMajorInterleavedK64, "nk64"}, + {LayoutTypeID::kRowMajorInterleavedK64, "tk64"}, + + {LayoutTypeID::kTensorNCHW, "nchw"}, + {LayoutTypeID::kTensorNCDHW, "ncdhw"}, + {LayoutTypeID::kTensorNHWC, "nhwc"}, + {LayoutTypeID::kTensorNDHWC, "ndhwc"}, + {LayoutTypeID::kTensorNC4HW4, "nc4hw4"}, + {LayoutTypeID::kTensorNC8HW8, "nc8hw8"}, + {LayoutTypeID::kTensorNC16HW16, "nc16hw16"}, + {LayoutTypeID::kTensorNC32HW32, "nc32hw32"}, + {LayoutTypeID::kTensorNC64HW64, "nc64hw64"}, + {LayoutTypeID::kTensorC4RSK4, "c4rsk4"}, + {LayoutTypeID::kTensorC8RSK8, "c8rsk8"}, + {LayoutTypeID::kTensorC16RSK16, "c16rsk16"}, + {LayoutTypeID::kTensorC32RSK32, "c32rsk32"}, + {LayoutTypeID::kTensorC64RSK64, "c64rsk64"}, + {LayoutTypeID::kTensorK4RSC4, "k4rsC4"}, + {LayoutTypeID::kUnknown, "*"}, + {LayoutTypeID::kInvalid, nullptr}}; + +/// Converts a LayoutTypeID enumerant to a string +char const* to_string(LayoutTypeID layout, bool pretty) { + for (auto const& alias : layout_aliases) { + if (alias.layout == layout) { + return alias.alias; + } + } + return pretty ? 
"Invalid" : "invalid"; +} + +/// Parses a LayoutTypeID enumerant from a string +template <> +LayoutTypeID from_string(std::string const& str) { + for (auto const& alias : layout_aliases) { + if (str.compare(alias.alias) == 0) { + return alias.layout; + } + } + return LayoutTypeID::kInvalid; +} + +/// Gets stride rank for the layout_id (static function) +int get_layout_stride_rank(LayoutTypeID layout_id) { + switch (layout_id) { + case LayoutTypeID::kColumnMajor: + return cutlass::layout::ColumnMajor::kStrideRank; + case LayoutTypeID::kRowMajor: + return cutlass::layout::RowMajor::kStrideRank; + case LayoutTypeID::kColumnMajorInterleavedK2: + return cutlass::layout::ColumnMajorInterleaved<2>::kStrideRank; + case LayoutTypeID::kRowMajorInterleavedK2: + return cutlass::layout::RowMajorInterleaved<2>::kStrideRank; + case LayoutTypeID::kColumnMajorInterleavedK4: + return cutlass::layout::ColumnMajorInterleaved<4>::kStrideRank; + case LayoutTypeID::kRowMajorInterleavedK4: + return cutlass::layout::RowMajorInterleaved<4>::kStrideRank; + case LayoutTypeID::kColumnMajorInterleavedK16: + return cutlass::layout::ColumnMajorInterleaved<16>::kStrideRank; + case LayoutTypeID::kRowMajorInterleavedK16: + return cutlass::layout::RowMajorInterleaved<16>::kStrideRank; + case LayoutTypeID::kColumnMajorInterleavedK32: + return cutlass::layout::ColumnMajorInterleaved<32>::kStrideRank; + case LayoutTypeID::kRowMajorInterleavedK32: + return cutlass::layout::RowMajorInterleaved<32>::kStrideRank; + case LayoutTypeID::kColumnMajorInterleavedK64: + return cutlass::layout::ColumnMajorInterleaved<64>::kStrideRank; + case LayoutTypeID::kRowMajorInterleavedK64: + return cutlass::layout::RowMajorInterleaved<64>::kStrideRank; + case LayoutTypeID::kTensorNCHW: + return cutlass::layout::TensorNCHW::kStrideRank; + case LayoutTypeID::kTensorNHWC: + return cutlass::layout::TensorNHWC::kStrideRank; + case LayoutTypeID::kTensorNDHWC: + return cutlass::layout::TensorNDHWC::kStrideRank; + case LayoutTypeID::kTensorNC32HW32: + return cutlass::layout::TensorNCxHWx<32>::kStrideRank; + case LayoutTypeID::kTensorNC64HW64: + return cutlass::layout::TensorNCxHWx<64>::kStrideRank; + case LayoutTypeID::kTensorC32RSK32: + return cutlass::layout::TensorCxRSKx<32>::kStrideRank; + case LayoutTypeID::kTensorC64RSK64: + return cutlass::layout::TensorCxRSKx<64>::kStrideRank; + default: + throw std::runtime_error( + "Unsupported LayoutTypeID in LayoutType::get_stride_rank"); + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +static struct { + char const* text; + char const* pretty; + OpcodeClassID enumerant; +} OpcodeClassID_enumerants[] = { + {"simt", "", OpcodeClassID::kSimt}, + {"tensorop", "", OpcodeClassID::kTensorOp}, + {"wmmatensorop", "", OpcodeClassID::kWmmaTensorOp}, + {"wmma", "", OpcodeClassID::kWmmaTensorOp}, +}; + +/// Converts a OpcodeClassID enumerant to a string +char const* to_string(OpcodeClassID type, bool pretty) { + for (auto const& possible : OpcodeClassID_enumerants) { + if (type == possible.enumerant) { + if (pretty) { + return possible.pretty; + } else { + return possible.text; + } + } + } + + return pretty ? 
"Invalid" : "invalid"; +} + +/// Converts a OpcodeClassID enumerant from a string +template <> +OpcodeClassID from_string(std::string const& str) { + for (auto const& possible : OpcodeClassID_enumerants) { + if ((str.compare(possible.text) == 0) || + (str.compare(possible.pretty) == 0)) { + return possible.enumerant; + } + } + + return OpcodeClassID::kInvalid; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +static struct { + char const* text; + char const* pretty; + ComplexTransform enumerant; +} ComplexTransform_enumerants[] = {{"n", "none", ComplexTransform::kNone}, + {"c", "conj", ComplexTransform::kConjugate}}; + +/// Converts a ComplexTransform enumerant to a string +char const* to_string(ComplexTransform type, bool pretty) { + for (auto const& possible : ComplexTransform_enumerants) { + if (type == possible.enumerant) { + if (pretty) { + return possible.pretty; + } else { + return possible.text; + } + } + } + + return pretty ? "Invalid" : "invalid"; +} + +/// Converts a ComplexTransform enumerant from a string +template <> +ComplexTransform from_string(std::string const& str) { + for (auto const& possible : ComplexTransform_enumerants) { + if ((str.compare(possible.text) == 0) || + (str.compare(possible.pretty) == 0)) { + return possible.enumerant; + } + } + + return ComplexTransform::kInvalid; +} + +static struct { + char const* text; + char const* pretty; + SplitKMode enumerant; +} SplitKMode_enumerants[] = { + {"none", "", SplitKMode::kNone}, + {"serial", "", SplitKMode::kSerial}, + {"parallel", "", SplitKMode::kParallel}, +}; + +/// Converts a SplitKMode enumerant to a string +char const* to_string(SplitKMode type, bool pretty) { + for (auto const& possible : SplitKMode_enumerants) { + if (type == possible.enumerant) { + if (pretty) { + return possible.pretty; + } else { + return possible.text; + } + } + } + + return pretty ? "Invalid" : "invalid"; +} + +/// Converts a SplitKMode enumerant from a string +template <> +SplitKMode from_string(std::string const& str) { + for (auto const& possible : SplitKMode_enumerants) { + if ((str.compare(possible.text) == 0) || + (str.compare(possible.pretty) == 0)) { + return possible.enumerant; + } + } + + return SplitKMode::kInvalid; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +static struct { + char const* text; + char const* pretty; + ConvModeID enumerant; +} ConvModeID_enumerants[] = { + {"cross", "", ConvModeID::kCrossCorrelation}, + {"conv", "", ConvModeID::kConvolution}, +}; + +/// Converts a ConvModeID enumerant to a string +char const* to_string(ConvModeID type, bool pretty) { + for (auto const& possible : ConvModeID_enumerants) { + if (type == possible.enumerant) { + if (pretty) { + return possible.pretty; + } else { + return possible.text; + } + } + } + + return pretty ? 
"Invalid" : "invalid"; +} + +/// Converts a ConvModeID enumerant from a string +template <> +ConvModeID from_string(std::string const& str) { + for (auto const& possible : ConvModeID_enumerants) { + if ((str.compare(possible.text) == 0) || + (str.compare(possible.pretty) == 0)) { + return possible.enumerant; + } + } + + return ConvModeID::kInvalid; +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +static struct { + char const* text; + char const* pretty; + IteratorAlgorithmID enumerant; +} IteratorAlgorithmID_enumerants[] = { + {"none", "", IteratorAlgorithmID::kNone}, + {"analytic", "", IteratorAlgorithmID::kAnalytic}, + {"optimized", "", IteratorAlgorithmID::kOptimized}, +}; + +/// Converts a ConvModeID enumerant to a string +char const* to_string(IteratorAlgorithmID type, bool pretty) { + for (auto const& possible : IteratorAlgorithmID_enumerants) { + if (type == possible.enumerant) { + if (pretty) { + return possible.pretty; + } else { + return possible.text; + } + } + } + + return pretty ? "Invalid" : "invalid"; +} + +/// Converts a ConvModeID enumerant from a string +template <> +IteratorAlgorithmID from_string(std::string const& str) { + for (auto const& possible : IteratorAlgorithmID_enumerants) { + if ((str.compare(possible.text) == 0) || + (str.compare(possible.pretty) == 0)) { + return possible.enumerant; + } + } + + return IteratorAlgorithmID::kInvalid; +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +static struct { + char const* text; + char const* pretty; + ConvKind enumerant; +} ConvKind_enumerants[] = { + {"unknown", "", ConvKind::kUnknown}, + {"fprop", "", ConvKind::kFprop}, + {"dgrad", "", ConvKind::kDgrad}, + {"wgrad", "", ConvKind::kWgrad}, +}; + +/// Converts a ConvKind enumerant to a string +char const* to_string(ConvKind type, bool pretty) { + for (auto const& possible : ConvKind_enumerants) { + if (type == possible.enumerant) { + if (pretty) { + return possible.pretty; + } else { + return possible.text; + } + } + } + + return pretty ? "Invalid" : "invalid"; +} + +/// Converts a ConvKind enumerant from a string +template <> +ConvKind from_string(std::string const& str) { + for (auto const& possible : ConvKind_enumerants) { + if ((str.compare(possible.text) == 0) || + (str.compare(possible.pretty) == 0)) { + return possible.enumerant; + } + } + + return ConvKind::kInvalid; +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Lexical cast a string to a byte array. Returns true if cast is successful or +/// false if invalid. 
+/// Lexical cast a string to a byte array. Returns true if cast is successful or
+/// false if invalid.
+bool lexical_cast(std::vector<uint8_t>& bytes, NumericTypeID type,
+                  std::string const& str) {
+    int size_bytes = sizeof_bits(type) / 8;
+    if (!size_bytes) {
+        return false;
+    }
+
+    bytes.resize(size_bytes, 0);
+
+    std::stringstream ss;
+    ss << str;
+
+    switch (type) {
+        case NumericTypeID::kU8: {
+            ss >> *reinterpret_cast<uint8_t*>(bytes.data());
+        } break;
+        case NumericTypeID::kU16: {
+            ss >> *reinterpret_cast<uint16_t*>(bytes.data());
+        } break;
+        case NumericTypeID::kU32: {
+            ss >> *reinterpret_cast<uint32_t*>(bytes.data());
+        } break;
+        case NumericTypeID::kU64: {
+            ss >> *reinterpret_cast<uint64_t*>(bytes.data());
+        } break;
+        case NumericTypeID::kS8: {
+            ss >> *reinterpret_cast<int8_t*>(bytes.data());
+        } break;
+        case NumericTypeID::kS16: {
+            ss >> *reinterpret_cast<int16_t*>(bytes.data());
+        } break;
+        case NumericTypeID::kS32: {
+            ss >> *reinterpret_cast<int32_t*>(bytes.data());
+        } break;
+        case NumericTypeID::kS64: {
+            ss >> *reinterpret_cast<int64_t*>(bytes.data());
+        } break;
+        case NumericTypeID::kF16: {
+            float tmp;
+            ss >> tmp;
+            *reinterpret_cast<half_t*>(bytes.data()) = static_cast<half_t>(tmp);
+        } break;
+        case NumericTypeID::kBF16: {
+            float tmp;
+            ss >> tmp;
+            *reinterpret_cast<bfloat16_t*>(bytes.data()) =
+                    static_cast<bfloat16_t>(tmp);
+        } break;
+        case NumericTypeID::kTF32: {
+            float tmp;
+            ss >> tmp;
+            *reinterpret_cast<tfloat32_t*>(bytes.data()) =
+                    static_cast<tfloat32_t>(tmp);
+        } break;
+        case NumericTypeID::kF32: {
+            ss >> *reinterpret_cast<float*>(bytes.data());
+        } break;
+        case NumericTypeID::kF64: {
+            ss >> *reinterpret_cast<double*>(bytes.data());
+        } break;
+        case NumericTypeID::kCF16: {
+            std::complex<float> tmp;
+            ss >> tmp;
+            cutlass::complex<half_t>* x =
+                    reinterpret_cast<cutlass::complex<half_t>*>(bytes.data());
+            x->real() = static_cast<half_t>(std::real(tmp));
+            x->imag() = static_cast<half_t>(std::imag(tmp));
+        } break;
+        case NumericTypeID::kCBF16: {
+            std::complex<float> tmp;
+            ss >> tmp;
+            cutlass::complex<bfloat16_t>* x =
+                    reinterpret_cast<cutlass::complex<bfloat16_t>*>(
+                            bytes.data());
+            x->real() = static_cast<bfloat16_t>(std::real(tmp));
+            x->imag() = static_cast<bfloat16_t>(std::imag(tmp));
+        } break;
+        case NumericTypeID::kCF32: {
+            ss >> *reinterpret_cast<std::complex<float>*>(bytes.data());
+        } break;
+        case NumericTypeID::kCTF32: {
+            std::complex<float> tmp;
+            ss >> tmp;
+            cutlass::complex<tfloat32_t>* x =
+                    reinterpret_cast<cutlass::complex<tfloat32_t>*>(
+                            bytes.data());
+            x->real() = static_cast<tfloat32_t>(std::real(tmp));
+            x->imag() = static_cast<tfloat32_t>(std::imag(tmp));
+        } break;
+        case NumericTypeID::kCF64: {
+            ss >> *reinterpret_cast<std::complex<double>*>(bytes.data());
+        } break;
+        default:
+            return false;
+    }
+
+    return true;
+}
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
+
+std::string lexical_cast(int64_t int_value) {
+    std::stringstream ss;
+    ss << int_value;
+    return ss.str();
+}
+
+/// Lexical cast TO a string FROM a byte array. Returns an empty string if the
+/// byte array does not match the size of the requested type.
+std::string lexical_cast(std::vector<uint8_t>& bytes, NumericTypeID type) {
+    size_t size_bytes = sizeof_bits(type) / 8;
+
+    if (!size_bytes || size_bytes != bytes.size()) {
+        return "";
+    }
+
+    std::stringstream ss;
+
+    switch (type) {
+        case NumericTypeID::kU8: {
+            ss << *reinterpret_cast<uint8_t*>(bytes.data());
+        } break;
+        case NumericTypeID::kU16: {
+            ss << *reinterpret_cast<uint16_t*>(bytes.data());
+        } break;
+        case NumericTypeID::kU32: {
+            ss << *reinterpret_cast<uint32_t*>(bytes.data());
+        } break;
+        case NumericTypeID::kU64: {
+            ss << *reinterpret_cast<uint64_t*>(bytes.data());
+        } break;
+        case NumericTypeID::kS8: {
+            ss << *reinterpret_cast<int8_t*>(bytes.data());
+        } break;
+        case NumericTypeID::kS16: {
+            ss << *reinterpret_cast<int16_t*>(bytes.data());
+        } break;
+        case NumericTypeID::kS32: {
+            ss << *reinterpret_cast<int32_t*>(bytes.data());
+        } break;
+        case NumericTypeID::kS64: {
+            ss << *reinterpret_cast<int64_t*>(bytes.data());
+        } break;
+        case NumericTypeID::kF16: {
+            float tmp = *reinterpret_cast<half_t*>(bytes.data());
+            ss << tmp;
+        } break;
+        case NumericTypeID::kBF16: {
+            float tmp = *reinterpret_cast<bfloat16_t*>(bytes.data());
+            ss << tmp;
+        } break;
+        case NumericTypeID::kTF32: {
+            float tmp = *reinterpret_cast<tfloat32_t*>(bytes.data());
+            ss << tmp;
+        } break;
+        case NumericTypeID::kF32: {
+            ss << *reinterpret_cast<float*>(bytes.data());
+        } break;
+        case NumericTypeID::kF64: {
+            ss << *reinterpret_cast<double*>(bytes.data());
+        } break;
+        case NumericTypeID::kCF16: {
+            cutlass::complex<half_t> const* x =
+                    reinterpret_cast<cutlass::complex<half_t> const*>(
+                            bytes.data());
+
+            ss << float(x->real());
+
+            if (x->imag() != cutlass::half_t()) {
+                ss << "+i" << float(x->imag());
+            }
+        } break;
+        case NumericTypeID::kCBF16: {
+            cutlass::complex<bfloat16_t> const* x =
+                    reinterpret_cast<cutlass::complex<bfloat16_t> const*>(
+                            bytes.data());
+
+            ss << float(x->real());
+
+            if (x->imag() != cutlass::bfloat16_t()) {
+                ss << "+i" << float(x->imag());
+            }
+        } break;
+        case NumericTypeID::kCF32: {
+            cutlass::complex<float> const* x =
+                    reinterpret_cast<cutlass::complex<float> const*>(
+                            bytes.data());
+
+            ss << x->real();
+
+            if (x->imag() != float()) {
+                ss << "+i" << x->imag();
+            }
+        } break;
+        case NumericTypeID::kCTF32: {
+            cutlass::complex<tfloat32_t> const* x =
+                    reinterpret_cast<cutlass::complex<tfloat32_t> const*>(
+                            bytes.data());
+
+            ss << float(x->real());
+
+            if (x->imag() != tfloat32_t()) {
+                ss << "+i" << float(x->imag());
+            }
+        } break;
+        case NumericTypeID::kCF64: {
+            cutlass::complex<double> const* x =
+                    reinterpret_cast<cutlass::complex<double> const*>(
+                            bytes.data());
+
+            ss << x->real();
+
+            if (x->imag() != double()) {
+                ss << "+i" << x->imag();
+            }
+        } break;
+        default:
+            return "";
+    }
+
+    return ss.str();
+}
+
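Taken together, the two overloads round-trip a scalar literal through the type-erased byte buffer; a sketch with illustrative values:

    std::vector<uint8_t> bytes;
    bool ok = lexical_cast(bytes, NumericTypeID::kF32, "1.5");  // string -> 4 bytes
    std::string s = lexical_cast(bytes, NumericTypeID::kF32);   // bytes -> "1.5"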
+/// Casts from a signed int64 to the destination type. Returns true if
+/// successful.
+bool cast_from_int64(std::vector<uint8_t>& bytes, NumericTypeID type,
+                     int64_t src) {
+    int size_bytes = sizeof_bits(type) / 8;
+    if (!size_bytes) {
+        return false;
+    }
+
+    bytes.resize(size_bytes, 0);
+
+    switch (type) {
+        case NumericTypeID::kU8: {
+            *reinterpret_cast<uint8_t*>(bytes.data()) =
+                    static_cast<uint8_t>(src);
+        } break;
+        case NumericTypeID::kU16: {
+            *reinterpret_cast<uint16_t*>(bytes.data()) =
+                    static_cast<uint16_t>(src);
+        } break;
+        case NumericTypeID::kU32: {
+            *reinterpret_cast<uint32_t*>(bytes.data()) =
+                    static_cast<uint32_t>(src);
+        } break;
+        case NumericTypeID::kU64: {
+            *reinterpret_cast<uint64_t*>(bytes.data()) =
+                    static_cast<uint64_t>(src);
+        } break;
+        case NumericTypeID::kS8: {
+            *reinterpret_cast<int8_t*>(bytes.data()) = static_cast<int8_t>(src);
+        } break;
+        case NumericTypeID::kS16: {
+            *reinterpret_cast<int16_t*>(bytes.data()) =
+                    static_cast<int16_t>(src);
+        } break;
+        case NumericTypeID::kS32: {
+            *reinterpret_cast<int32_t*>(bytes.data()) =
+                    static_cast<int32_t>(src);
+        } break;
+        case NumericTypeID::kS64: {
+            *reinterpret_cast<int64_t*>(bytes.data()) =
+                    static_cast<int64_t>(src);
+        } break;
+        case NumericTypeID::kF16: {
+            *reinterpret_cast<half_t*>(bytes.data()) =
+                    static_cast<half_t>(float(src));
+        } break;
+        case NumericTypeID::kBF16: {
+            *reinterpret_cast<bfloat16_t*>(bytes.data()) =
+                    static_cast<bfloat16_t>(float(src));
+        } break;
+        case NumericTypeID::kTF32: {
+            *reinterpret_cast<tfloat32_t*>(bytes.data()) =
+                    static_cast<tfloat32_t>(float(src));
+        } break;
+        case NumericTypeID::kF32: {
+            *reinterpret_cast<float*>(bytes.data()) = static_cast<float>(src);
+        } break;
+        case NumericTypeID::kF64: {
+            *reinterpret_cast<double*>(bytes.data()) = double(src);
+        } break;
+        case NumericTypeID::kCF16: {
+            cutlass::complex<half_t>* x =
+                    reinterpret_cast<cutlass::complex<half_t>*>(bytes.data());
+            x->real() = static_cast<half_t>(float(src));
+            x->imag() = static_cast<half_t>(float(0));
+        } break;
+        case NumericTypeID::kCF32: {
+            *reinterpret_cast<cutlass::complex<float>*>(bytes.data()) =
+                    cutlass::complex<float>(float(src), float(0));
+        } break;
+        case NumericTypeID::kCF64: {
+            *reinterpret_cast<cutlass::complex<double>*>(bytes.data()) =
+                    cutlass::complex<double>(double(src), double(0));
+        } break;
+        default:
+            return false;
+    }
+
+    return true;
+}
+
+/// Casts from an unsigned int64 to the destination type. Returns true if
+/// successful.
+bool cast_from_uint64(std::vector<uint8_t>& bytes, NumericTypeID type,
+                      uint64_t src) {
+    int size_bytes = sizeof_bits(type) / 8;
+    if (!size_bytes) {
+        return false;
+    }
+
+    bytes.resize(size_bytes, 0);
+
+    switch (type) {
+        case NumericTypeID::kU8: {
+            *reinterpret_cast<uint8_t*>(bytes.data()) =
+                    static_cast<uint8_t>(src);
+        } break;
+        case NumericTypeID::kU16: {
+            *reinterpret_cast<uint16_t*>(bytes.data()) =
+                    static_cast<uint16_t>(src);
+        } break;
+        case NumericTypeID::kU32: {
+            *reinterpret_cast<uint32_t*>(bytes.data()) =
+                    static_cast<uint32_t>(src);
+        } break;
+        case NumericTypeID::kU64: {
+            *reinterpret_cast<uint64_t*>(bytes.data()) =
+                    static_cast<uint64_t>(src);
+        } break;
+        case NumericTypeID::kS8: {
+            *reinterpret_cast<int8_t*>(bytes.data()) = static_cast<int8_t>(src);
+        } break;
+        case NumericTypeID::kS16: {
+            *reinterpret_cast<int16_t*>(bytes.data()) =
+                    static_cast<int16_t>(src);
+        } break;
+        case NumericTypeID::kS32: {
+            *reinterpret_cast<int32_t*>(bytes.data()) =
+                    static_cast<int32_t>(src);
+        } break;
+        case NumericTypeID::kS64: {
+            *reinterpret_cast<int64_t*>(bytes.data()) =
+                    static_cast<int64_t>(src);
+        } break;
+        case NumericTypeID::kF16: {
+            *reinterpret_cast<half_t*>(bytes.data()) =
+                    static_cast<half_t>(float(src));
+        } break;
+        case NumericTypeID::kBF16: {
+            *reinterpret_cast<bfloat16_t*>(bytes.data()) =
+                    static_cast<bfloat16_t>(float(src));
+        } break;
+        case NumericTypeID::kTF32: {
+            *reinterpret_cast<tfloat32_t*>(bytes.data()) =
+                    static_cast<tfloat32_t>(float(src));
+        } break;
+        case NumericTypeID::kF32: {
+            *reinterpret_cast<float*>(bytes.data()) = static_cast<float>(src);
+        } break;
+        case NumericTypeID::kF64: {
+            *reinterpret_cast<double*>(bytes.data()) = double(src);
+        } break;
+        case NumericTypeID::kCF16: {
+            cutlass::complex<half_t>* x =
+                    reinterpret_cast<cutlass::complex<half_t>*>(bytes.data());
+            x->real() = static_cast<half_t>(float(src));
+            x->imag() = static_cast<half_t>(float(0));
+        } break;
+        case NumericTypeID::kCF32: {
+            *reinterpret_cast<cutlass::complex<float>*>(bytes.data()) =
+                    std::complex<float>(float(src), float(0));
+        } break;
+        case NumericTypeID::kCF64: {
+            *reinterpret_cast<cutlass::complex<double>*>(bytes.data()) =
+                    std::complex<double>(double(src), double(0));
+        } break;
+        default:
+            return false;
+    }
+
+    return true;
+}
+
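The cast_from_* helpers materialize a host-side scalar (e.g. an epilogue alpha/beta value, though that use is an assumption of this sketch) in the operand's numeric type:

    std::vector<uint8_t> alpha;
    if (!cast_from_double(alpha, NumericTypeID::kF16, 1.0)) {
        // the destination type is not handled by this helper
    }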
+/// Casts from a double to the destination type. Returns true if
+/// successful.
+bool cast_from_double(std::vector<uint8_t>& bytes, NumericTypeID type,
+                      double src) {
+    int size_bytes = sizeof_bits(type) / 8;
+    if (!size_bytes) {
+        return false;
+    }
+
+    bytes.resize(size_bytes, 0);
+
+    switch (type) {
+        case NumericTypeID::kU8: {
+            *reinterpret_cast<uint8_t*>(bytes.data()) =
+                    static_cast<uint8_t>(src);
+        } break;
+        case NumericTypeID::kU16: {
+            *reinterpret_cast<uint16_t*>(bytes.data()) =
+                    static_cast<uint16_t>(src);
+        } break;
+        case NumericTypeID::kU32: {
+            *reinterpret_cast<uint32_t*>(bytes.data()) =
+                    static_cast<uint32_t>(src);
+        } break;
+        case NumericTypeID::kU64: {
+            *reinterpret_cast<uint64_t*>(bytes.data()) =
+                    static_cast<uint64_t>(src);
+        } break;
+        case NumericTypeID::kS8: {
+            *reinterpret_cast<int8_t*>(bytes.data()) = static_cast<int8_t>(src);
+        } break;
+        case NumericTypeID::kS16: {
+            *reinterpret_cast<int16_t*>(bytes.data()) =
+                    static_cast<int16_t>(src);
+        } break;
+        case NumericTypeID::kS32: {
+            *reinterpret_cast<int32_t*>(bytes.data()) =
+                    static_cast<int32_t>(src);
+        } break;
+        case NumericTypeID::kS64: {
+            *reinterpret_cast<int64_t*>(bytes.data()) =
+                    static_cast<int64_t>(src);
+        } break;
+        case NumericTypeID::kF16: {
+            *reinterpret_cast<half_t*>(bytes.data()) =
+                    static_cast<half_t>(float(src));
+        } break;
+        case NumericTypeID::kBF16: {
+            *reinterpret_cast<bfloat16_t*>(bytes.data()) =
+                    static_cast<bfloat16_t>(float(src));
+        } break;
+        case NumericTypeID::kTF32: {
+            *reinterpret_cast<tfloat32_t*>(bytes.data()) =
+                    static_cast<tfloat32_t>(float(src));
+        } break;
+        case NumericTypeID::kF32: {
+            *reinterpret_cast<float*>(bytes.data()) = static_cast<float>(src);
+        } break;
+        case NumericTypeID::kF64: {
+            *reinterpret_cast<double*>(bytes.data()) = src;
+        } break;
+        case NumericTypeID::kCF16: {
+            cutlass::complex<half_t>* x =
+                    reinterpret_cast<cutlass::complex<half_t>*>(bytes.data());
+            x->real() = static_cast<half_t>(float(src));
+            x->imag() = static_cast<half_t>(float(0));
+        } break;
+        case NumericTypeID::kCBF16: {
+            cutlass::complex<bfloat16_t>* x =
+                    reinterpret_cast<cutlass::complex<bfloat16_t>*>(
+                            bytes.data());
+            x->real() = static_cast<bfloat16_t>(src);
+            x->imag() = static_cast<bfloat16_t>(0);
+        } break;
+        case NumericTypeID::kCF32: {
+            *reinterpret_cast<cutlass::complex<float>*>(bytes.data()) =
+                    cutlass::complex<float>(float(src), float());
+        } break;
+        case NumericTypeID::kCTF32: {
+            *reinterpret_cast<cutlass::complex<tfloat32_t>*>(bytes.data()) =
+                    cutlass::complex<tfloat32_t>(tfloat32_t(src), tfloat32_t());
+        } break;
+        case NumericTypeID::kCF64: {
+            *reinterpret_cast<cutlass::complex<double>*>(bytes.data()) =
+                    cutlass::complex<double>(src, double());
+        } break;
+        default:
+            return false;
+    }
+
+    return true;
+}
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
+
+static struct {
+    char const* text;
+    char const* pretty;
+    conv::Operator enumerant;
+} ConvOperator_enumerants[] = {
+        {"fprop", "Fprop", conv::Operator::kFprop},
+        {"dgrad", "Dgrad", conv::Operator::kDgrad},
+        {"wgrad", "Wgrad", conv::Operator::kWgrad},
+};
+
+/// Converts a conv::Operator enumerant to a string
+char const* to_string(conv::Operator conv_op, bool pretty) {
+    for (auto const& possible : ConvOperator_enumerants) {
+        if (conv_op == possible.enumerant) {
+            if (pretty) {
+                return possible.pretty;
+            } else {
+                return possible.text;
+            }
+        }
+    }
+
+    return pretty ? 
"Invalid" : "invalid"; +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +static struct { + char const* text; + char const* pretty; + conv::ConvType enumerant; +} ConvType_enumerants[] = { + {"convolution", "Convolution", conv::ConvType::kConvolution}, + {"batch_convolution", "BatchConvolution", + conv::ConvType::kBatchConvolution}, + {"local", "Local", conv::ConvType::kLocal}, + {"local_share", "LocalShare", conv::ConvType::kLocalShare}, +}; + +/// Converts a ConvType enumerant to a string +char const* to_string(conv::ConvType type, bool pretty) { + for (auto const& possible : ConvType_enumerants) { + if (type == possible.enumerant) { + if (pretty) { + return possible.pretty; + } else { + return possible.text; + } + } + } + + return pretty ? "Invalid" : "invalid"; +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +static struct { + char const* text; + char const* pretty; + ArchTagID enumerant; +} ArchTagID_enumerants[] = { + {"sm_50", "Sm50", ArchTagID::kSm50}, + {"sm_60", "Sm60", ArchTagID::kSm60}, + {"sm_61", "Sm61", ArchTagID::kSm61}, + {"sm_70", "Sm70", ArchTagID::kSm70}, + {"sm_72", "Sm72", ArchTagID::kSm72}, + {"sm_75", "Sm75", ArchTagID::kSm75}, + {"sm_80", "Sm80", ArchTagID::kSm80}, + {"sm_86", "Sm86", ArchTagID::kSm86}, +}; + +/// Converts an ArchTagID enumerant to a string +char const* to_string(ArchTagID tag, bool pretty) { + for (auto const& possible : ArchTagID_enumerants) { + if (tag == possible.enumerant) { + if (pretty) { + return possible.pretty; + } else { + return possible.text; + } + } + } + + return pretty ? "Invalid" : "invalid"; +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +static struct { + char const* text; + char const* pretty; + epilogue::EpilogueType enumerant; +} EpilogueType_enumerants[] = { + {"bias_add_linear_combination", "BiasAddLinearCombination", + epilogue::EpilogueType::kBiasAddLinearCombination}, + {"bias_add_linear_combination_clamp", "BiasAddLinearCombinationClamp", + epilogue::EpilogueType::kBiasAddLinearCombinationClamp}, + {"bias_add_linear_combination_hswish", "BiasAddLinearCombinationHSwish", + epilogue::EpilogueType::kBiasAddLinearCombinationHSwish}, + {"bias_add_linear_combination_hswish_clamp", + "BiasAddLinearCombinationHSwishClamp", + epilogue::EpilogueType::kBiasAddLinearCombinationHSwishClamp}, + {"bias_add_linear_combination_relu", "BiasAddLinearCombinationRelu", + epilogue::EpilogueType::kBiasAddLinearCombinationRelu}, + {"bias_add_linear_combination_relu_clamp", + "BiasAddLinearCombinationReluClamp", + epilogue::EpilogueType::kBiasAddLinearCombinationReluClamp}, + {"conversion", "Conversion", epilogue::EpilogueType::kConversion}, + {"linear_combination", "LinearCombination", + epilogue::EpilogueType::kLinearCombination}, + {"linear_combination_clamp", "LinearCombination_clamp", + epilogue::EpilogueType::kLinearCombinationClamp}, + {"linear_combination_planar_complex", "LinearCombinationPlanarComplex", + epilogue::EpilogueType::kLinearCombinationPlanarComplex}, + {"linear_combination_relu", "LinearCombinationRelu", + epilogue::EpilogueType::kLinearCombinationRelu}, + {"linear_combination_sigmoid", "LinearCombinationSigmoid", + epilogue::EpilogueType::kLinearCombinationSigmoid}, +}; + +/// Converts an EpilogueType enumerant to a string +char const* to_string(epilogue::EpilogueType type, bool pretty) { + for (auto const& possible : EpilogueType_enumerants) { + 
+        if (type == possible.enumerant) {
+            if (pretty) {
+                return possible.pretty;
+            } else {
+                return possible.text;
+            }
+        }
+    }
+
+    return pretty ? "Invalid" : "invalid";
+}
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
+
+static struct {
+    char const* text;
+    char const* pretty;
+    ThreadblockSwizzleID enumerant;
+} ThreadblockSwizzleID_enumerants[] = {
+        {"gemm_identity", "GemmIdentityThreadblockSwizzle",
+         ThreadblockSwizzleID::kGemmIdentity},
+        {"gemm_horizontal", "GemmHorizontalThreadblockSwizzle",
+         ThreadblockSwizzleID::kGemmHorizontal},
+        {"gemm_batched_identity", "GemmBatchedIdentityThreadblockSwizzle",
+         ThreadblockSwizzleID::kGemmBatchedIdentity},
+        {"gemm_split_k_identity", "GemmSplitKIdentityThreadblockSwizzle",
+         ThreadblockSwizzleID::kGemmSplitKIdentity},
+        {"gemm_split_k_horizontal", "GemmSplitKHorizontalThreadblockSwizzle",
+         ThreadblockSwizzleID::kGemmSplitKHorizontal},
+        {"gemv_batched_strided_default",
+         "GemvBatchedStridedThreadblockDefaultSwizzle",
+         ThreadblockSwizzleID::kGemvBatchedStridedDefault},
+        {"gemv_batched_strided_reduction",
+         "GemvBatchedStridedThreadblockReductionSwizzle",
+         ThreadblockSwizzleID::kGemvBatchedStridedReduction},
+        {"convolution_fprop_cxrskx", "ConvolutionFpropCxRSKxThreadblockSwizzle",
+         ThreadblockSwizzleID::kConvolutionFpropCxRSKx},
+        {"convolution_dgrad_cxrskx", "ConvolutionDgradCxRSKxThreadblockSwizzle",
+         ThreadblockSwizzleID::kConvolutionDgradCxRSKx},
+        {"convolution_fprop_ncxhwx", "ConvolutionFpropNCxHWxThreadblockSwizzle",
+         ThreadblockSwizzleID::kConvolutionFpropNCxHWx},
+        {"convolution_fprop_nhwc", "ConvolutionFpropTransThreadblockSwizzle",
+         ThreadblockSwizzleID::kConvolutionFpropTrans},
+        {"convolution_dgrad_ncxhwx", "ConvolutionDgradNCxHWxThreadblockSwizzle",
+         ThreadblockSwizzleID::kConvolutionDgradNCxHWx},
+};
+
+/// Converts a ThreadblockSwizzleID enumerant to a string
+char const* to_string(ThreadblockSwizzleID threadblock_swizzle, bool pretty) {
+    for (auto const& possible : ThreadblockSwizzleID_enumerants) {
+        if (threadblock_swizzle == possible.enumerant) {
+            if (pretty) {
+                return possible.pretty;
+            } else {
+                return possible.text;
+            }
+        }
+    }
+
+    return pretty ? "Invalid" : "invalid";
+}
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Converts a bool value to a string
+char const* to_string(bool val, bool pretty) {
+    if (val) {
+        return pretty ? "True" : "true";
+    } else {
+        return pretty ? "False" : "false";
+    }
+}
"False" : "false"; + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +static struct { + char const* text; + char const* pretty; + MathOperationID enumerant; +} MathOperationID_enumerants[] = { + {"add", "Add", MathOperationID::kAdd}, + {"multiply_add", "MultiplyAdd", MathOperationID::kMultiplyAdd}, + {"multiply_add_saturate", "MultiplyAddSaturate", + MathOperationID::kMultiplyAddSaturate}, + {"multiply_add_fast_bf16", "MultiplyAddFastBF16", + MathOperationID::kMultiplyAddFastBF16}, + {"multiply_add_fast_f16", "MultiplyAddFastF16", + MathOperationID::kMultiplyAddFastF16}, + {"multiply_add_complex", "MultiplyAddComplex", + MathOperationID::kMultiplyAddComplex}, + {"multiply_add_gaussian_complex", "MultiplyAddGaussianComplex", + MathOperationID::kMultiplyAddGaussianComplex}, + {"xor_popc", "XorPopc", MathOperationID::kXorPopc}, +}; + +/// Converts a MathOperationID enumerant to a string +char const* to_string(MathOperationID math_op, bool pretty) { + for (auto const& possible : MathOperationID_enumerants) { + if (math_op == possible.enumerant) { + if (pretty) { + return possible.pretty; + } else { + return possible.text; + } + } + } + + return pretty ? "Invalid" : "invalid"; +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +static struct { + char const* text; + char const* pretty; + conv::ImplicitGemmMode enumerant; +} ImplicitGemmMode_enumerants[] = { + {"gemm_nt", "GemmNT", conv::ImplicitGemmMode::GEMM_NT}, + {"gemm_tn", "GemmTN", conv::ImplicitGemmMode::GEMM_TN}, +}; + +/// Converts an ImplicitGemmMode enumerant to a string +char const* to_string(conv::ImplicitGemmMode mode, bool pretty) { + for (auto const& possible : ImplicitGemmMode_enumerants) { + if (mode == possible.enumerant) { + if (pretty) { + return possible.pretty; + } else { + return possible.text; + } + } + } + + return pretty ? "Invalid" : "invalid"; +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +#endif diff --git a/dnn/src/cuda/cutlass/util.h b/dnn/src/cuda/cutlass/util.h new file mode 100644 index 00000000..cb6e6f99 --- /dev/null +++ b/dnn/src/cuda/cutlass/util.h @@ -0,0 +1,218 @@ +/*************************************************************************************************** + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + *modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + *notice, this list of conditions and the following disclaimer in the + *documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its + *contributors may be used to endorse or promote products derived from this + *software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + *DISCLAIMED. 
diff --git a/dnn/src/cuda/cutlass/util.h b/dnn/src/cuda/cutlass/util.h
new file mode 100644
index 00000000..cb6e6f99
--- /dev/null
+++ b/dnn/src/cuda/cutlass/util.h
@@ -0,0 +1,218 @@
+/***************************************************************************************************
+ * Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *     * Redistributions of source code must retain the above copyright notice,
+ *       this list of conditions and the following disclaimer.
+ *     * Redistributions in binary form must reproduce the above copyright
+ *       notice, this list of conditions and the following disclaimer in the
+ *       documentation and/or other materials provided with the distribution.
+ *     * Neither the name of the NVIDIA CORPORATION nor the names of its
+ *       contributors may be used to endorse or promote products derived from
+ *       this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ * ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
+ * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+ * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+ * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
+ * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
+ * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/**
+ * \file dnn/src/cuda/cutlass/util.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
+ * express or implied.
+ */
+
+#pragma once
+
+#include "src/cuda/cutlass/library.h"
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass {
+namespace library {
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Lexical cast from string
+template <typename T>
+T from_string(std::string const&);
+
+/// Converts a Provider enumerant to a string
+char const* to_string(Provider provider, bool pretty = false);
+
+/// Parses a Provider enumerant from a string
+template <>
+Provider from_string<Provider>(std::string const& str);
+
+/// Converts a GemmKind enumerant to a string
+char const* to_string(GemmKind type, bool pretty = false);
+
+/// Converts an OperationKind enumerant to a string
+char const* to_string(OperationKind type, bool pretty = false);
+
+/// Parses an OperationKind enumerant from a string
+template <>
+OperationKind from_string<OperationKind>(std::string const& str);
+
+/// Converts a NumericTypeID enumerant to a string
+char const* to_string(NumericTypeID type, bool pretty = false);
+
+/// Parses a NumericTypeID enumerant from a string
+template <>
+NumericTypeID from_string<NumericTypeID>(std::string const& str);
+
+/// Returns the size of a data type in bits
+int sizeof_bits(NumericTypeID type);
+
+/// Returns true if the numeric type is a complex data type or false if
+/// real-valued
+bool is_complex_type(NumericTypeID type);
+
+/// Returns the real-valued type underlying a type (only different from 'type'
+/// if complex)
+NumericTypeID get_real_type(NumericTypeID type);
+
+/// Returns true if numeric type is integer
+bool is_integer_type(NumericTypeID type);
+
+/// Returns true if numeric type is signed
+bool is_signed_type(NumericTypeID type);
+
+/// Returns true if numeric type is a signed integer
+bool is_signed_integer(NumericTypeID type);
+
+/// Returns true if numeric type is an unsigned integer
+bool is_unsigned_integer(NumericTypeID type);
+
+/// Returns true if numeric type is floating-point type
+bool is_float_type(NumericTypeID type);
+
+/// To string method for cutlass::Status
+char const* to_string(Status status, bool pretty = false);
+
+/// Converts a LayoutTypeID enumerant to a string
+char const* to_string(LayoutTypeID layout, bool pretty = false);
+
+/// Parses a LayoutTypeID enumerant from a string
+template <>
+LayoutTypeID from_string<LayoutTypeID>(std::string const& str);
+
+/// Returns the rank of a layout's stride based on the LayoutTypeID
+int get_layout_stride_rank(LayoutTypeID layout_id);
+
+/// Converts an OpcodeClassID enumerant to a string
+char const* to_string(OpcodeClassID type, bool pretty = false);
+
+/// Parses an OpcodeClassID enumerant from a string
+template <>
+OpcodeClassID from_string<OpcodeClassID>(std::string const& str);
+
+/// Converts a ComplexTransform enumerant to a string
+char const* to_string(ComplexTransform type, bool pretty = false);
+
+/// Parses a ComplexTransform enumerant from a string
+template <>
+ComplexTransform from_string<ComplexTransform>(std::string const& str);
+
+/// Converts a SplitKMode enumerant to a string
+char const* to_string(SplitKMode split_k_mode, bool pretty = false);
+
+/// Parses a SplitKMode enumerant from a string
+template <>
+SplitKMode from_string<SplitKMode>(std::string const& str);
+
+/// Converts a ConvModeID enumerant to a string
+char const* to_string(ConvModeID type, bool pretty = false);
+
+/// Parses a ConvModeID enumerant from a string
+template <>
+ConvModeID from_string<ConvModeID>(std::string const& str);
+
+/// Converts an IteratorAlgorithmID enumerant to a string
+char const* to_string(IteratorAlgorithmID type, bool pretty = false);
+
+/// Parses an IteratorAlgorithmID enumerant from a string
+template <>
+IteratorAlgorithmID from_string<IteratorAlgorithmID>(std::string const& str);
+
+/// Converts a ConvKind enumerant to a string
+char const* to_string(ConvKind type, bool pretty = false);
+
+/// Parses a ConvKind enumerant from a string
+template <>
+ConvKind from_string<ConvKind>(std::string const& str);
+
+/// Lexical cast from int64_t to string
+std::string lexical_cast(int64_t int_value);
+
+/// Lexical cast from a string to a byte array. Returns true if cast is
+/// successful or false if invalid.
+bool lexical_cast(std::vector<uint8_t>& bytes, NumericTypeID type,
+                  std::string const& str);
+
+/// Lexical cast TO a string FROM a byte array. Returns true if cast is
+/// successful or false if invalid.
+std::string lexical_cast(std::vector<uint8_t>& bytes, NumericTypeID type);
+
+/// Casts from a signed int64 to the destination type. Returns true if
+/// successful.
+bool cast_from_int64(std::vector<uint8_t>& bytes, NumericTypeID type,
+                     int64_t src);
+
+/// Casts from an unsigned int64 to the destination type. Returns true if
+/// successful.
+bool cast_from_uint64(std::vector<uint8_t>& bytes, NumericTypeID type,
+                      uint64_t src);
+
+/// Casts from a real value represented as a double to the destination type.
+/// Returns true if successful.
+bool cast_from_double(std::vector<uint8_t>& bytes, NumericTypeID type,
+                      double src);
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Converts a conv::Operator enumerant to a string
+char const* to_string(conv::Operator conv_op, bool pretty = false);
+
+/// Converts a ConvType enumerant to a string
+char const* to_string(conv::ConvType type, bool pretty = false);
+
+/// Converts an ArchTagID enumerant to a string
+char const* to_string(ArchTagID tag, bool pretty = false);
+
+/// Converts an EpilogueType enumerant to a string
+char const* to_string(epilogue::EpilogueType type, bool pretty = false);
+
+/// Converts a ThreadblockSwizzleID enumerant to a string
+char const* to_string(ThreadblockSwizzleID threadblock_swizzle,
+                      bool pretty = false);
+
+/// Converts a bool value to a string
+char const* to_string(bool val, bool pretty = false);
+
+/// Converts a MathOperationID enumerant to a string
+char const* to_string(MathOperationID math_op, bool pretty = false);
+
+/// Converts an ImplicitGemmMode enumerant to a string
+char const* to_string(conv::ImplicitGemmMode mode, bool pretty = false);
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace library
+} // namespace cutlass
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/dnn/src/cuda/matrix_mul/cutlass_float32_simt.cpp b/dnn/src/cuda/matrix_mul/cutlass_float32_simt.cpp
index 51a8a7ee..05fa960b 100644
--- a/dnn/src/cuda/matrix_mul/cutlass_float32_simt.cpp
+++ b/dnn/src/cuda/matrix_mul/cutlass_float32_simt.cpp
@@ -10,15 +10,14 @@
  * implied.
  */
 
+#include "src/cuda/cutlass/singleton.h"
 #include "src/cuda/handle.h"
 #include "src/cuda/matrix_mul/algos.h"
-#include "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuh"
 #include "src/cuda/utils.h"
 
 #if CUDA_VERSION >= 9020
 using namespace megdnn;
 using namespace cuda;
-using namespace cutlass_wrapper;
 
 bool MatrixMulForwardImpl::AlgoFloat32SIMT::is_available(
         const SizeArgs& args) const {
@@ -44,25 +43,62 @@ size_t MatrixMulForwardImpl::AlgoFloat32SIMT::get_workspace_in_bytes(
 }
 
 void MatrixMulForwardImpl::AlgoFloat32SIMT::exec(const ExecArgs& args) const {
-    size_t lda = args.tensor_a.layout.stride[0],
-           ldb = args.tensor_b.layout.stride[0],
-           ldc = args.tensor_c.layout.stride[0];
+    int64_t lda = args.tensor_a.layout.stride[0],
+            ldb = args.tensor_b.layout.stride[0],
+            ldc = args.tensor_c.layout.stride[0];
     auto&& param = args.opr->param();
     int m = args.tensor_c.layout.shape[0], n = args.tensor_c.layout.shape[1],
         k = args.tensor_a.layout.shape[param.transposeA ? 0 : 1];
-    GemmCoord problem_size{m, n, k};
+    cutlass::gemm::GemmCoord problem_size{m, n, k};
     auto&& stream = cuda_stream(args.opr->handle());
     int* workspace = reinterpret_cast<int*>(args.workspace.raw_ptr);
-    return cutlass_matrix_mul_float32_simt(
-            args.tensor_a.ptr<dt_float32>(), param.transposeA, lda,
-            args.tensor_b.ptr<dt_float32>(), param.transposeB, ldb,
-            args.tensor_c.ptr<dt_float32>(), ldc, workspace, problem_size, 1.f,
-            0.f,
-            GemmCoord{m_algo_param.threadblock_m, m_algo_param.threadblock_n,
-                      m_algo_param.threadblock_k},
-            GemmCoord{m_algo_param.warp_m, m_algo_param.warp_n,
-                      m_algo_param.warp_k},
-            stream);
+
+    // \note these epilogue constants are passed to the struct `GemmArguments`
+    // by pointer and reinterpreted as ElementCompute*; a different dtype here
+    // results in undefined epilogue behavior
+    float alpha = 1.f, beta = 0.f;
+
+    using namespace cutlass::library;
+
+    auto layoutA = param.transposeA ? LayoutTypeID::kColumnMajor
+                                    : LayoutTypeID::kRowMajor;
+    auto layoutB = param.transposeB ? LayoutTypeID::kColumnMajor
+                                    : LayoutTypeID::kRowMajor;
+
+    GemmKey key{NumericTypeID::kF32,
+                layoutA,
+                NumericTypeID::kF32,
+                layoutB,
+                NumericTypeID::kF32,
+                LayoutTypeID::kRowMajor,
+                m_algo_param.threadblock_m,
+                m_algo_param.threadblock_n,
+                m_algo_param.threadblock_k,
+                m_algo_param.warp_m,
+                m_algo_param.warp_n,
+                m_algo_param.warp_k,
+                1,
+                1,
+                1,
+                2,
+                SplitKMode::kNone};
+
+    const Operation* op = Singleton::get().operation_table.find_op(key);
+
+    GemmArguments gemm_args{problem_size,
+                            args.tensor_a.raw_ptr,
+                            args.tensor_b.raw_ptr,
+                            args.tensor_c.raw_ptr,
+                            args.tensor_c.raw_ptr,
+                            lda,
+                            ldb,
+                            ldc,
+                            ldc,
+                            1,
+                            &alpha,
+                            &beta};
+
+    cutlass_check(op->run(&gemm_args, workspace, stream));
 }
 #endif
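The `\note` in the rewritten `exec()` deserves a concrete illustration: `alpha`/`beta` cross the `GemmArguments` boundary as type-erased pointers and are reinterpreted as the selected kernel's `ElementCompute`. A self-contained sketch of the failure mode follows; `epilogue_scale` is a hypothetical stand-in, not a library function:

    #include <cstdio>

    // Stand-in for the type-erased boundary: like the operation table, the
    // callee receives void const* and trusts it to point at a float.
    static void epilogue_scale(const void* alpha_erased) {
        float alpha = *static_cast<const float*>(alpha_erased);
        std::printf("alpha = %f\n", alpha);
    }

    int main() {
        float alpha_ok = 1.f;
        epilogue_scale(&alpha_ok);   // reads 1.0 as intended

        double alpha_bad = 1.0;      // deliberately wrong dtype: the callee
        epilogue_scale(&alpha_bad);  // reads part of a double's bits as a
                                     // float, yielding garbage, exactly the
                                     // undefined epilogue behavior warned about
        return 0;
    }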
diff --git a/dnn/src/cuda/matrix_mul/cutlass_float32_simt_split_k.cpp b/dnn/src/cuda/matrix_mul/cutlass_float32_simt_split_k.cpp
index 5d68dfec..2baa7f63 100644
--- a/dnn/src/cuda/matrix_mul/cutlass_float32_simt_split_k.cpp
+++ b/dnn/src/cuda/matrix_mul/cutlass_float32_simt_split_k.cpp
@@ -10,15 +10,14 @@
  * implied.
  */
 
+#include "src/cuda/cutlass/singleton.h"
 #include "src/cuda/handle.h"
 #include "src/cuda/matrix_mul/algos.h"
-#include "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuh"
 #include "src/cuda/utils.h"
 
 #if CUDA_VERSION >= 9020
 using namespace megdnn;
 using namespace cuda;
-using namespace cutlass_wrapper;
 
 bool MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::is_available(
         const SizeArgs& args) const {
@@ -50,26 +49,63 @@ size_t MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::get_workspace_in_bytes(
 
 void MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::exec(
         const ExecArgs& args) const {
-    size_t lda = args.tensor_a.layout.stride[0],
-           ldb = args.tensor_b.layout.stride[0],
-           ldc = args.tensor_c.layout.stride[0];
+    int64_t lda = args.tensor_a.layout.stride[0],
+            ldb = args.tensor_b.layout.stride[0],
+            ldc = args.tensor_c.layout.stride[0];
     auto&& param = args.opr->param();
     int m = args.tensor_c.layout.shape[0], n = args.tensor_c.layout.shape[1],
         k = args.tensor_a.layout.shape[param.transposeA ? 0 : 1];
-    GemmCoord problem_size{m, n, k};
+    cutlass::gemm::GemmCoord problem_size{m, n, k};
     int split_k_slices = std::max(1, k / n);
     auto&& stream = cuda_stream(args.opr->handle());
     int* workspace = reinterpret_cast<int*>(args.workspace.raw_ptr);
-    return cutlass_matrix_mul_float32_simt(
-            args.tensor_a.ptr<dt_float32>(), param.transposeA, lda,
-            args.tensor_b.ptr<dt_float32>(), param.transposeB, ldb,
-            args.tensor_c.ptr<dt_float32>(), ldc, workspace, problem_size, 1.f,
-            0.f,
-            GemmCoord{m_algo_param.threadblock_m, m_algo_param.threadblock_n,
-                      m_algo_param.threadblock_k},
-            GemmCoord{m_algo_param.warp_m, m_algo_param.warp_n,
-                      m_algo_param.warp_k},
-            stream, split_k_slices);
+
+    // \note these epilogue constants are passed to the struct `GemmArguments`
+    // by pointer and reinterpreted as ElementCompute*; a different dtype here
+    // results in undefined epilogue behavior
+    float alpha = 1.f, beta = 0.f;
+
+    using namespace cutlass::library;
+
+    auto layoutA = param.transposeA ? LayoutTypeID::kColumnMajor
+                                    : LayoutTypeID::kRowMajor;
+    auto layoutB = param.transposeB ? LayoutTypeID::kColumnMajor
+                                    : LayoutTypeID::kRowMajor;
+
+    GemmKey key{NumericTypeID::kF32,
+                layoutA,
+                NumericTypeID::kF32,
+                layoutB,
+                NumericTypeID::kF32,
+                LayoutTypeID::kRowMajor,
+                m_algo_param.threadblock_m,
+                m_algo_param.threadblock_n,
+                m_algo_param.threadblock_k,
+                m_algo_param.warp_m,
+                m_algo_param.warp_n,
+                m_algo_param.warp_k,
+                1,
+                1,
+                1,
+                2,
+                SplitKMode::kParallel};
+
+    Operation const* op = Singleton::get().operation_table.find_op(key);
+
+    GemmArguments gemm_args{problem_size,
+                            args.tensor_a.raw_ptr,
+                            args.tensor_b.raw_ptr,
+                            args.tensor_c.raw_ptr,
+                            args.tensor_c.raw_ptr,
+                            lda,
+                            ldb,
+                            ldc,
+                            ldc,
+                            split_k_slices,
+                            &alpha,
+                            &beta};
+
+    cutlass_check(op->run(&gemm_args, workspace, stream));
 }
 #endif
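The `split_k_slices = std::max(1, k / n)` heuristic above sizes the parallel split-K reduction to the problem shape: the more the reduction dimension K dominates N, the more slices are used. A minimal sketch with hypothetical sizes:

    #include <algorithm>
    #include <cstdio>

    // Same heuristic as AlgoFloat32SIMTSplitK::exec: roughly one slice per
    // N-sized chunk of K, and never fewer than one slice.
    static int choose_split_k_slices(int k, int n) {
        return std::max(1, k / n);
    }

    int main() {
        std::printf("%d\n", choose_split_k_slices(4096, 128));  // 32 slices
        std::printf("%d\n", choose_split_k_slices(64, 128));    // 1: no split
        return 0;
    }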
diff --git a/dnn/src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cu b/dnn/src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cu
deleted file mode 100644
index fe478b83..00000000
--- a/dnn/src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cu
+++ /dev/null
@@ -1,157 +0,0 @@
-/**
- * \file dnn/src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cu
- * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
- *
- * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
- * express or implied.
- */
-// ignore warning of cutlass
-#include "cuda.h"
-#if __CUDACC_VER_MAJOR__ > 9 || \
-        (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
-#pragma GCC diagnostic push
-#pragma GCC diagnostic ignored "-Wunused-parameter"
-#pragma GCC diagnostic ignored "-Wstrict-aliasing"
-
-#include "cutlass/gemm/device/gemm.h"
-#include "cutlass/gemm/device/gemm_splitk_parallel.h"
-#include "cutlass/gemm/kernel/default_gemv.h"
-#include "src/common/opr_param_defs_enumv.cuh"
-#include "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuh"
-#pragma GCC diagnostic pop
-
-using namespace megdnn;
-using namespace cuda;
-using namespace cutlass_wrapper;
-
-/* ================= cutlass kernel wrapper for f32 matrix mul ================
- */
-#define DISPATCH(cb)                                                          \
-    cb(64, 256, 8, 32, 64, 8);                                                \
-    cb(256, 64, 8, 64, 32, 8);                                                \
-    cb(32, 256, 8, 16, 64, 8);                                                \
-    cb(256, 32, 8, 64, 16, 8);                                                \
-    cb(128, 128, 8, 32, 64, 8);                                               \
-    cb(128, 64, 8, 64, 32, 8);                                                \
-    cb(64, 128, 8, 32, 64, 8);                                                \
-    cb(128, 32, 8, 64, 32, 8);                                                \
-    cb(32, 128, 8, 32, 64, 8);                                                \
-    cb(64, 64, 8, 32, 64, 8);                                                 \
-    cb(32, 64, 8, 32, 64, 8);                                                 \
-    cb(64, 32, 8, 64, 32, 8);                                                 \
-    cb(32, 32, 8, 32, 32, 8);                                                 \
-    cb(8, 32, 8, 8, 32, 8);                                                   \
-    cb(16, 32, 8, 16, 32, 8);                                                 \
-    cb(16, 64, 8, 16, 64, 8);                                                 \
-    cb(16, 128, 8, 16, 64, 8);                                                \
-    megdnn_assert(false,                                                      \
-                  "unsupported threadblock shape (%dx%dx%d) and warp shape "  \
-                  "(%dx%dx%d)",                                               \
-                  threadblock_shape.m(), threadblock_shape.n(),               \
-                  threadblock_shape.k(), warp_shape.m(), warp_shape.n(),      \
-                  warp_shape.k());
-void megdnn::cuda::cutlass_wrapper::cutlass_matrix_mul_float32_simt(
-        const float* d_A, bool transpose_A, size_t lda, const float* d_B,
-        bool transpose_B, size_t ldb, float* d_C, size_t ldc, int* workspace,
-        GemmCoord const& problem_size, float alpha, float beta,
-        const GemmCoord& threadblock_shape, const GemmCoord& warp_shape,
-        cudaStream_t stream, int split_k_slices) {
-    static constexpr int kEpilogueElementsPerAccess = 1;
-    using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
-            float, kEpilogueElementsPerAccess, float, float>;
-    typename EpilogueOp::Params epilogue{alpha, beta};
-    if (split_k_slices == 1) {
-#define cb(threadblock_m_, threadblock_n_, threadblock_k_, warp_m_, warp_n_,  \
-           warp_k_)                                                           \
-    if (threadblock_shape.m() == threadblock_m_ &&                            \
-        threadblock_shape.n() == threadblock_n_ &&                            \
-        threadblock_shape.k() == threadblock_k_ &&                            \
-        warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ &&             \
-        warp_shape.k() == warp_k_) {                                          \
-        using ThreadBlockShape =                                              \
-                cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_,      \
-                                         threadblock_k_>;                     \
-        using WarpShape =                                                     \
-                cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>;          \
-        using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;           \
-        using Gemm = cutlass::gemm::device::Gemm<                             \
-                float, LayoutA, float, LayoutB, float,                        \
-                cutlass::layout::RowMajor, float, cutlass::arch::OpClassSimt, \
-                cutlass::arch::Sm50, ThreadBlockShape, WarpShape,             \
-                InstructionShape, EpilogueOp,                                 \
-                cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, \
-                2>;                                                           \
-        return cutlass_matrix_mul_wrapper<Gemm>(d_A, lda, d_B, ldb, d_C, ldc, \
-                                                workspace, problem_size,      \
-                                                epilogue, stream);            \
-    }
-        if (!transpose_A && !transpose_B) {
-            using LayoutA = cutlass::layout::RowMajor;
-            using LayoutB = cutlass::layout::RowMajor;
-            DISPATCH(cb)
-        } else if (!transpose_A && transpose_B) {
-            using LayoutA = cutlass::layout::RowMajor;
-            using LayoutB = cutlass::layout::ColumnMajor;
-            DISPATCH(cb)
-        } else if (transpose_A && !transpose_B) {
-            using LayoutA = cutlass::layout::ColumnMajor;
-            using LayoutB = cutlass::layout::RowMajor;
-            DISPATCH(cb)
-        } else {
-            megdnn_assert(transpose_A && transpose_B);
-            using LayoutA = cutlass::layout::ColumnMajor;
-            using LayoutB = cutlass::layout::ColumnMajor;
-            DISPATCH(cb)
-        }
-#undef cb
-    } else {
-#define cb(threadblock_m_, threadblock_n_, threadblock_k_, warp_m_, warp_n_,  \
-           warp_k_)                                                           \
-    if (threadblock_shape.m() == threadblock_m_ &&                            \
-        threadblock_shape.n() == threadblock_n_ &&                            \
-        threadblock_shape.k() == threadblock_k_ &&                            \
-        warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ &&             \
-        warp_shape.k() == warp_k_) {                                          \
-        using ThreadBlockShape =                                              \
-                cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_,      \
-                                         threadblock_k_>;                     \
-        using WarpShape =                                                     \
-                cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>;          \
-        using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;           \
-        using Gemm = cutlass::gemm::device::GemmSplitKParallel<               \
-                float, LayoutA, float, LayoutB, float,                        \
-                cutlass::layout::RowMajor, float, cutlass::arch::OpClassSimt, \
-                cutlass::arch::Sm50, ThreadBlockShape, WarpShape,             \
-                InstructionShape, EpilogueOp>;                                \
-        return cutlass_matrix_mul_wrapper<Gemm>(                              \
-                d_A, lda, d_B, ldb, d_C, ldc, workspace, problem_size,        \
-                epilogue, stream, split_k_slices);                            \
-    }
-        if (!transpose_A && !transpose_B) {
-            using LayoutA = cutlass::layout::RowMajor;
-            using LayoutB = cutlass::layout::RowMajor;
-            DISPATCH(cb)
-        } else if (!transpose_A && transpose_B) {
-            using LayoutA = cutlass::layout::RowMajor;
-            using LayoutB = cutlass::layout::ColumnMajor;
-            DISPATCH(cb)
-        } else if (transpose_A && !transpose_B) {
-            using LayoutA = cutlass::layout::ColumnMajor;
-            using LayoutB = cutlass::layout::RowMajor;
-            DISPATCH(cb)
-        } else {
-            megdnn_assert(transpose_A && transpose_B);
-            using LayoutA = cutlass::layout::ColumnMajor;
-            using LayoutB = cutlass::layout::ColumnMajor;
-            DISPATCH(cb)
-        }
-#undef cb
-    }
-}
-#undef DISPATCH
-
-#endif
-
-// vim: syntax=cuda.doxygen
diff --git a/dnn/src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuh b/dnn/src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuh
index 86144b91..371c04a9 100644
--- a/dnn/src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuh
+++ b/dnn/src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuh
@@ -21,22 +21,6 @@ namespace cutlass_wrapper {
 using GemmCoord = cutlass::gemm::GemmCoord;
 using BatchedGemmCoord = cutlass::gemm::BatchedGemmCoord;
 
-template <typename Gemm>
-void cutlass_matrix_mul_wrapper(
-        const typename Gemm::ElementA* d_A, size_t lda,
-        const typename Gemm::ElementB* d_B, size_t ldb,
-        typename Gemm::ElementC* d_C, size_t ldc, int* workspace,
-        GemmCoord const& problem_size,
-        typename Gemm::EpilogueOutputOp::Params const& epilogue,
-        cudaStream_t stream, int split_k_slices = 1);
-
-void cutlass_matrix_mul_float32_simt(
-        const float* d_A, bool transpose_A, size_t lda, const float* d_B,
-        bool transpose_B, size_t ldb, float* d_C, size_t ldc, int* workspace,
-        GemmCoord const& problem_size, float alpha, float beta,
-        const GemmCoord& threadblock_shape, const GemmCoord& warp_shape,
-        cudaStream_t stream, int split_k_slices = 1);
-
 template <typename GemvKernel>
 void cutlass_vector_matrix_mul_batched_strided_wrapper(
         BatchedGemmCoord const& problem_size,
diff --git a/dnn/src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuinl b/dnn/src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuinl
deleted file mode 100644
index 4610a8f6..00000000
--- a/dnn/src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuinl
+++ /dev/null
@@ -1,57 +0,0 @@
-/**
- * \file
- * dnn/src/cuda/matrix_mul/matrix_mul_float_simt_cutlass_wrapper.cuinl
- * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
- *
- * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
- * express or implied.
- */
-#include "cutlass/gemm/device/gemm.h"
-#include "cutlass/gemm/device/gemm_splitk_parallel.h"
-#include "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuh"
-
-using namespace megdnn;
-using namespace cuda;
-using namespace cutlass_wrapper;
-
-template <typename Gemm>
-void megdnn::cuda::cutlass_wrapper::cutlass_matrix_mul_wrapper(
-        const typename Gemm::ElementA* d_A, size_t lda,
-        const typename Gemm::ElementB* d_B, size_t ldb,
-        typename Gemm::ElementC* d_C, size_t ldc, int* workspace,
-        GemmCoord const& problem_size,
-        typename Gemm::EpilogueOutputOp::Params const& epilogue,
-        cudaStream_t stream, int split_k_slices) {
-    using TensorRefA = cutlass::TensorRef<typename Gemm::ElementA const,
-                                          typename Gemm::LayoutA>;
-    using TensorRefB = cutlass::TensorRef<typename Gemm::ElementB const,
-                                          typename Gemm::LayoutB>;
-    using TensorRefC = cutlass::TensorRef<typename Gemm::ElementC const,
-                                          typename Gemm::LayoutC>;
-    using TensorRefD =
-            cutlass::TensorRef<typename Gemm::ElementC, typename Gemm::LayoutC>;
-    TensorRefA tensor_a{const_cast<typename Gemm::ElementA*>(d_A),
-                        typename Gemm::LayoutA{static_cast<int>(lda)}};
-    TensorRefB tensor_b{const_cast<typename Gemm::ElementB*>(d_B),
-                        typename Gemm::LayoutB{static_cast<int>(ldb)}};
-    TensorRefC tensor_c{nullptr, typename Gemm::LayoutC{static_cast<int>(ldc)}};
-    TensorRefD tensor_d{d_C, typename Gemm::LayoutC{static_cast<int>(ldc)}};
-
-    typename Gemm::Arguments arguments{problem_size,
-                                       tensor_a,
-                                       tensor_b,
-                                       tensor_c,
-                                       tensor_d.non_const_ref(),
-                                       epilogue,
-                                       split_k_slices};
-    Gemm gemm_op;
-    cutlass_check(gemm_op.initialize(arguments, workspace));
-    cutlass_check(gemm_op(stream));
-    after_kernel_launch();
-}
-
-// vim: syntax=cuda.doxygen
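Taken together, the deletions above carry the point of this patch: compile-time `DISPATCH`/`cb` tables are replaced by runtime lookups keyed on the same (dtype, layout, threadblock, warp, instruction, stages, split-K) tuple. A sketch of how one entry from the deleted list, `cb(128, 32, 8, 64, 32, 8)`, would now be reached through the operation table; the helper name is hypothetical, and the `GemmKey` field order follows the usage in `cutlass_float32_simt.cpp` above:

    #include "src/cuda/cutlass/singleton.h"

    using namespace cutlass::library;

    // The (128x32x8 threadblock, 64x32x8 warp) f32 SIMT kernel, previously
    // matched by the DISPATCH macro, found via the runtime operation table.
    const Operation* find_f32_simt_gemm_128x32(bool transpose_A,
                                               bool transpose_B) {
        GemmKey key{NumericTypeID::kF32,
                    transpose_A ? LayoutTypeID::kColumnMajor
                                : LayoutTypeID::kRowMajor,
                    NumericTypeID::kF32,
                    transpose_B ? LayoutTypeID::kColumnMajor
                                : LayoutTypeID::kRowMajor,
                    NumericTypeID::kF32,
                    LayoutTypeID::kRowMajor,
                    128, 32, 8,  // threadblock shape, one of the old cb() rows
                    64, 32, 8,   // warp shape
                    1, 1, 1,     // instruction shape (SIMT)
                    2,           // stages
                    SplitKMode::kNone};
        return Singleton::get().operation_table.find_op(key);
    }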