From 336761253deccab67eafb680f40186b30e973cf0 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 13 Jul 2021 17:15:35 +0800 Subject: [PATCH] feat(dnn/cuda): add tensorcore matmul for fp16 data type GitOrigin-RevId: 025c591f75afcef8fd58034a9cdd1ae8528bbda1 --- dnn/scripts/cutlass_generator/BUILD | 2 + dnn/scripts/cutlass_generator/gemm_operation.py | 25 +- dnn/scripts/cutlass_generator/gen_list.py | 2 + dnn/scripts/cutlass_generator/generator.py | 138 +++++++++- dnn/scripts/cutlass_generator/list.bzl | 292 ++++++++++++++++++++- dnn/src/CMakeLists.txt | 2 + dnn/src/cuda/cutlass/initialize_all.cu | 4 + dnn/src/cuda/cutlass/operation_table.cpp | 4 + dnn/src/cuda/cutlass/operation_table.h | 15 +- dnn/src/cuda/matrix_mul/algos.cpp | 21 +- dnn/src/cuda/matrix_mul/algos.h | 130 +++++++-- .../cuda/matrix_mul/cutlass_float16_tensorop.cpp | 154 +++++++++++ .../cutlass_float16_tensorop_split_k.cpp | 165 ++++++++++++ dnn/src/cuda/matrix_mul/cutlass_float32_simt.cpp | 11 +- .../matrix_mul/cutlass_float32_simt_split_k.cpp | 14 +- .../cuda/matrix_mul/cutlass_matrix_mul_base.cpp | 136 ++++++++++ dnn/src/cuda/matrix_mul/opr_impl.h | 3 + dnn/test/common/matrix_mul.cpp | 4 +- dnn/test/common/matrix_mul.h | 4 +- dnn/test/cuda/cutlass_matmul.cpp | 91 +++++++ 20 files changed, 1168 insertions(+), 49 deletions(-) create mode 100644 dnn/src/cuda/matrix_mul/cutlass_float16_tensorop.cpp create mode 100644 dnn/src/cuda/matrix_mul/cutlass_float16_tensorop_split_k.cpp create mode 100644 dnn/src/cuda/matrix_mul/cutlass_matrix_mul_base.cpp diff --git a/dnn/scripts/cutlass_generator/BUILD b/dnn/scripts/cutlass_generator/BUILD index cb54d00c..bb1e0b5f 100644 --- a/dnn/scripts/cutlass_generator/BUILD +++ b/dnn/scripts/cutlass_generator/BUILD @@ -5,6 +5,8 @@ genrule( outs = cutlass_gen_list, cmd = """GEN=$(location //brain/megbrain/dnn/scripts/cutlass_generator:generator.py) CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations gemm --type simt $(@D) + CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations gemm --type tensorop884 $(@D) + CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations gemm --type tensorop1688 $(@D) CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations gemv --type simt $(@D) CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations deconv --type simt $(@D) CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations conv2d --type simt $(@D) diff --git a/dnn/scripts/cutlass_generator/gemm_operation.py b/dnn/scripts/cutlass_generator/gemm_operation.py index 90b1c940..a6f33b18 100644 --- a/dnn/scripts/cutlass_generator/gemm_operation.py +++ b/dnn/scripts/cutlass_generator/gemm_operation.py @@ -252,7 +252,8 @@ def GeneratesGemm(tile, data_type, layout_a, layout_b, layout_c, min_cc, align_a if tile.math_instruction.element_accumulator == DataType.s32: epilogues = [EpilogueFunctor.LinearCombinationClamp] else: - assert tile.math_instruction.element_accumulator == DataType.f32 + assert tile.math_instruction.element_accumulator == DataType.f32 or \ + tile.math_instruction.element_accumulator == DataType.f16 epilogues = [EpilogueFunctor.LinearCombination] for epilogue in epilogues: @@ -799,7 +800,22 @@ class EmitGemmSplitKParallelInstance: ${epilogue_vector_length}, ${element_accumulator}, ${element_epilogue} - > + >, + cutlass::epilogue::thread::Convert< + ${element_accumulator}, + ${epilogue_vector_length}, + ${element_accumulator} + >, + cutlass::reduction::thread::ReduceAdd< + ${element_accumulator}, + ${element_accumulator}, + ${epilogue_vector_length} + >, + 
cutlass::gemm::threadblock::GemmSplitKHorizontalThreadblockSwizzle, + ${stages}, + ${align_a}, + ${align_b}, + ${math_operation} >; """ def emit(self, operation): @@ -831,7 +847,10 @@ class EmitGemmSplitKParallelInstance: 'epilogue_vector_length': str(epilogue_vector_length), 'element_epilogue': str(DataTypeTag[operation.element_epilogue]), 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor], - 'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor], + 'stages': str(operation.tile_description.stages), + 'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation], + 'align_a': str(operation.A.alignment), + 'align_b': str(operation.B.alignment), } return SubstituteTemplate(self.template, values) diff --git a/dnn/scripts/cutlass_generator/gen_list.py b/dnn/scripts/cutlass_generator/gen_list.py index 7c61f73d..d3b06776 100644 --- a/dnn/scripts/cutlass_generator/gen_list.py +++ b/dnn/scripts/cutlass_generator/gen_list.py @@ -32,6 +32,8 @@ if __name__ == "__main__": f.write("# Generated by dnn/scripts/cutlass_generator/gen_list.py\n\n") f.write("cutlass_gen_list = [\n") write_op_list(f, "gemm", "simt") + write_op_list(f, "gemm", "tensorop1688") + write_op_list(f, "gemm", "tensorop884") write_op_list(f, "gemv", "simt") write_op_list(f, "deconv", "simt") write_op_list(f, "conv2d", "simt") diff --git a/dnn/scripts/cutlass_generator/generator.py b/dnn/scripts/cutlass_generator/generator.py index b49aec21..cd7f810d 100644 --- a/dnn/scripts/cutlass_generator/generator.py +++ b/dnn/scripts/cutlass_generator/generator.py @@ -597,6 +597,131 @@ def GenerateGemv_Simt(args): return operations # +def GeneratesGemm_TensorOp_1688(args): + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.RowMajor), # nn + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.RowMajor), # nt + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor), # tn + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.RowMajor), # tt + ] + + math_instructions = [ + MathInstruction( \ + [16, 8, 8], \ + DataType.f16, DataType.f16, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + MathInstruction( \ + [16, 8, 8], \ + DataType.f16, DataType.f16, DataType.f16, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + ] + + min_cc = 75 + max_cc = 1024 + + alignment_constraints = [8, 4, 2, + #1 + ] + + operations = [] + for math_inst in math_instructions: + for layout in layouts: + for align in alignment_constraints: + tile_descriptions = [ + TileDescription([256, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), +## comment some configuration to reduce compilation time and binary size +# TileDescription([ 64, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), +# TileDescription([128, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), +# TileDescription([ 64, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_a, + math_inst.element_accumulator, + ] + + for tile in tile_descriptions: + operations += GeneratesGemm(tile, \ + data_type, \ + layout[0], \ + layout[1], \ + layout[2], \ + min_cc, \ + align * 16, \ + align * 16, \ + align * 16) + return operations + +# +def GeneratesGemm_TensorOp_884(args): + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, 
LayoutType.RowMajor), # nn + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.RowMajor), # nt + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor), # tn + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.RowMajor), # tt + ] + + math_instructions = [ + MathInstruction( \ + [8, 8, 4], \ + DataType.f16, DataType.f16, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + MathInstruction( \ + [8, 8, 4], \ + DataType.f16, DataType.f16, DataType.f16, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + ] + + min_cc = 70 + max_cc = 75 + + alignment_constraints = [8, 4, 2, + # 1 + ] + + operations = [] + for math_inst in math_instructions: + for layout in layouts: + for align in alignment_constraints: + tile_descriptions = [ + TileDescription([256, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), +## comment some configuration to reduce compilation time and binary size +# TileDescription([ 64, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), +# TileDescription([128, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), +# TileDescription([ 64, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_a, + math_inst.element_accumulator, + ] + + for tile in tile_descriptions: + operations += GeneratesGemm(tile, \ + data_type, \ + layout[0], \ + layout[1], \ + layout[2], \ + min_cc, \ + align * 16, \ + align * 16, \ + align * 16) + + return operations + +# def GenerateConv2dOperations(args): if args.type == "simt": return GenerateConv2d_Simt(args) @@ -613,9 +738,14 @@ def GenerateDeconvOperations(args): return GenerateDeconv_Simt(args) def GenerateGemmOperations(args): - assert args.type == "simt", "operation gemm only support" \ - "simt. (got:{})".format(args.type) - return GenerateGemm_Simt(args) + if args.type == "tensorop884": + return GeneratesGemm_TensorOp_884(args) + elif args.type == "tensorop1688": + return GeneratesGemm_TensorOp_1688(args) + else: + assert args.type == "simt", "operation gemm only support" \ + "simt. 
(got:{})".format(args.type) + return GenerateGemm_Simt(args) def GenerateGemvOperations(args): assert args.type == "simt", "operation gemv only support" \ @@ -631,7 +761,7 @@ if __name__ == "__main__": parser.add_argument("--operations", type=str, choices=['gemm', 'gemv', 'conv2d', 'deconv'], required=True, help="Specifies the operation to generate (gemm, gemv, conv2d, deconv)") parser.add_argument("output", type=str, help="output directory for CUTLASS kernel files") - parser.add_argument("--type", type=str, choices=['simt', 'tensorop8816', 'tensorop8832'], + parser.add_argument("--type", type=str, choices=['simt', 'tensorop8816', 'tensorop8832', 'tensorop884', 'tensorop1688'], default='simt', help="kernel type of CUTLASS kernel generator") gemv_wrapper_path = "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper_batched_gemv_strided.cuinl" diff --git a/dnn/scripts/cutlass_generator/list.bzl b/dnn/scripts/cutlass_generator/list.bzl index 0d4b649a..7aaae2d6 100644 --- a/dnn/scripts/cutlass_generator/list.bzl +++ b/dnn/scripts/cutlass_generator/list.bzl @@ -138,6 +138,296 @@ cutlass_gen_list = [ "cutlass_simt_sgemm_256x64_8x2_tt_align1.cu", "cutlass_simt_sgemm_split_k_parallel_256x64_8x2_tt_align1.cu", "all_gemm_simt_operations.cu", + "cutlass_tensorop_f16_s1688gemm_f16_256x128_32x2_nn_align8.cu", + "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_256x128_32x2_nn_align8.cu", + "cutlass_tensorop_f16_s1688gemm_f16_128x256_32x2_nn_align8.cu", + "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x256_32x2_nn_align8.cu", + "cutlass_tensorop_f16_s1688gemm_f16_128x128_32x2_nn_align8.cu", + "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x128_32x2_nn_align8.cu", + "cutlass_tensorop_f16_s1688gemm_f16_256x128_32x2_nn_align4.cu", + "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_256x128_32x2_nn_align4.cu", + "cutlass_tensorop_f16_s1688gemm_f16_128x256_32x2_nn_align4.cu", + "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x256_32x2_nn_align4.cu", + "cutlass_tensorop_f16_s1688gemm_f16_128x128_32x2_nn_align4.cu", + "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x128_32x2_nn_align4.cu", + "cutlass_tensorop_f16_s1688gemm_f16_256x128_32x2_nn_align2.cu", + "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_256x128_32x2_nn_align2.cu", + "cutlass_tensorop_f16_s1688gemm_f16_128x256_32x2_nn_align2.cu", + "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x256_32x2_nn_align2.cu", + "cutlass_tensorop_f16_s1688gemm_f16_128x128_32x2_nn_align2.cu", + "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x128_32x2_nn_align2.cu", + "cutlass_tensorop_f16_s1688gemm_f16_256x128_32x2_nt_align8.cu", + "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_256x128_32x2_nt_align8.cu", + "cutlass_tensorop_f16_s1688gemm_f16_128x256_32x2_nt_align8.cu", + "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x256_32x2_nt_align8.cu", + "cutlass_tensorop_f16_s1688gemm_f16_128x128_32x2_nt_align8.cu", + "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x128_32x2_nt_align8.cu", + "cutlass_tensorop_f16_s1688gemm_f16_256x128_32x2_nt_align4.cu", + "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_256x128_32x2_nt_align4.cu", + "cutlass_tensorop_f16_s1688gemm_f16_128x256_32x2_nt_align4.cu", + "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x256_32x2_nt_align4.cu", + "cutlass_tensorop_f16_s1688gemm_f16_128x128_32x2_nt_align4.cu", + "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x128_32x2_nt_align4.cu", + "cutlass_tensorop_f16_s1688gemm_f16_256x128_32x2_nt_align2.cu", + 
"cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_256x128_32x2_nt_align2.cu", + "cutlass_tensorop_f16_s1688gemm_f16_128x256_32x2_nt_align2.cu", + "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x256_32x2_nt_align2.cu", + "cutlass_tensorop_f16_s1688gemm_f16_128x128_32x2_nt_align2.cu", + "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x128_32x2_nt_align2.cu", + "cutlass_tensorop_f16_s1688gemm_f16_256x128_32x2_tn_align8.cu", + "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_256x128_32x2_tn_align8.cu", + "cutlass_tensorop_f16_s1688gemm_f16_128x256_32x2_tn_align8.cu", + "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x256_32x2_tn_align8.cu", + "cutlass_tensorop_f16_s1688gemm_f16_128x128_32x2_tn_align8.cu", + "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x128_32x2_tn_align8.cu", + "cutlass_tensorop_f16_s1688gemm_f16_256x128_32x2_tn_align4.cu", + "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_256x128_32x2_tn_align4.cu", + "cutlass_tensorop_f16_s1688gemm_f16_128x256_32x2_tn_align4.cu", + "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x256_32x2_tn_align4.cu", + "cutlass_tensorop_f16_s1688gemm_f16_128x128_32x2_tn_align4.cu", + "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x128_32x2_tn_align4.cu", + "cutlass_tensorop_f16_s1688gemm_f16_256x128_32x2_tn_align2.cu", + "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_256x128_32x2_tn_align2.cu", + "cutlass_tensorop_f16_s1688gemm_f16_128x256_32x2_tn_align2.cu", + "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x256_32x2_tn_align2.cu", + "cutlass_tensorop_f16_s1688gemm_f16_128x128_32x2_tn_align2.cu", + "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x128_32x2_tn_align2.cu", + "cutlass_tensorop_f16_s1688gemm_f16_256x128_32x2_tt_align8.cu", + "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_256x128_32x2_tt_align8.cu", + "cutlass_tensorop_f16_s1688gemm_f16_128x256_32x2_tt_align8.cu", + "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x256_32x2_tt_align8.cu", + "cutlass_tensorop_f16_s1688gemm_f16_128x128_32x2_tt_align8.cu", + "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x128_32x2_tt_align8.cu", + "cutlass_tensorop_f16_s1688gemm_f16_256x128_32x2_tt_align4.cu", + "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_256x128_32x2_tt_align4.cu", + "cutlass_tensorop_f16_s1688gemm_f16_128x256_32x2_tt_align4.cu", + "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x256_32x2_tt_align4.cu", + "cutlass_tensorop_f16_s1688gemm_f16_128x128_32x2_tt_align4.cu", + "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x128_32x2_tt_align4.cu", + "cutlass_tensorop_f16_s1688gemm_f16_256x128_32x2_tt_align2.cu", + "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_256x128_32x2_tt_align2.cu", + "cutlass_tensorop_f16_s1688gemm_f16_128x256_32x2_tt_align2.cu", + "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x256_32x2_tt_align2.cu", + "cutlass_tensorop_f16_s1688gemm_f16_128x128_32x2_tt_align2.cu", + "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x128_32x2_tt_align2.cu", + "cutlass_tensorop_h1688gemm_256x128_32x2_nn_align8.cu", + "cutlass_tensorop_h1688gemm_split_k_parallel_256x128_32x2_nn_align8.cu", + "cutlass_tensorop_h1688gemm_128x256_32x2_nn_align8.cu", + "cutlass_tensorop_h1688gemm_split_k_parallel_128x256_32x2_nn_align8.cu", + "cutlass_tensorop_h1688gemm_128x128_32x2_nn_align8.cu", + "cutlass_tensorop_h1688gemm_split_k_parallel_128x128_32x2_nn_align8.cu", + "cutlass_tensorop_h1688gemm_256x128_32x2_nn_align4.cu", + 
"cutlass_tensorop_h1688gemm_split_k_parallel_256x128_32x2_nn_align4.cu", + "cutlass_tensorop_h1688gemm_128x256_32x2_nn_align4.cu", + "cutlass_tensorop_h1688gemm_split_k_parallel_128x256_32x2_nn_align4.cu", + "cutlass_tensorop_h1688gemm_128x128_32x2_nn_align4.cu", + "cutlass_tensorop_h1688gemm_split_k_parallel_128x128_32x2_nn_align4.cu", + "cutlass_tensorop_h1688gemm_256x128_32x2_nn_align2.cu", + "cutlass_tensorop_h1688gemm_split_k_parallel_256x128_32x2_nn_align2.cu", + "cutlass_tensorop_h1688gemm_128x256_32x2_nn_align2.cu", + "cutlass_tensorop_h1688gemm_split_k_parallel_128x256_32x2_nn_align2.cu", + "cutlass_tensorop_h1688gemm_128x128_32x2_nn_align2.cu", + "cutlass_tensorop_h1688gemm_split_k_parallel_128x128_32x2_nn_align2.cu", + "cutlass_tensorop_h1688gemm_256x128_32x2_nt_align8.cu", + "cutlass_tensorop_h1688gemm_split_k_parallel_256x128_32x2_nt_align8.cu", + "cutlass_tensorop_h1688gemm_128x256_32x2_nt_align8.cu", + "cutlass_tensorop_h1688gemm_split_k_parallel_128x256_32x2_nt_align8.cu", + "cutlass_tensorop_h1688gemm_128x128_32x2_nt_align8.cu", + "cutlass_tensorop_h1688gemm_split_k_parallel_128x128_32x2_nt_align8.cu", + "cutlass_tensorop_h1688gemm_256x128_32x2_nt_align4.cu", + "cutlass_tensorop_h1688gemm_split_k_parallel_256x128_32x2_nt_align4.cu", + "cutlass_tensorop_h1688gemm_128x256_32x2_nt_align4.cu", + "cutlass_tensorop_h1688gemm_split_k_parallel_128x256_32x2_nt_align4.cu", + "cutlass_tensorop_h1688gemm_128x128_32x2_nt_align4.cu", + "cutlass_tensorop_h1688gemm_split_k_parallel_128x128_32x2_nt_align4.cu", + "cutlass_tensorop_h1688gemm_256x128_32x2_nt_align2.cu", + "cutlass_tensorop_h1688gemm_split_k_parallel_256x128_32x2_nt_align2.cu", + "cutlass_tensorop_h1688gemm_128x256_32x2_nt_align2.cu", + "cutlass_tensorop_h1688gemm_split_k_parallel_128x256_32x2_nt_align2.cu", + "cutlass_tensorop_h1688gemm_128x128_32x2_nt_align2.cu", + "cutlass_tensorop_h1688gemm_split_k_parallel_128x128_32x2_nt_align2.cu", + "cutlass_tensorop_h1688gemm_256x128_32x2_tn_align8.cu", + "cutlass_tensorop_h1688gemm_split_k_parallel_256x128_32x2_tn_align8.cu", + "cutlass_tensorop_h1688gemm_128x256_32x2_tn_align8.cu", + "cutlass_tensorop_h1688gemm_split_k_parallel_128x256_32x2_tn_align8.cu", + "cutlass_tensorop_h1688gemm_128x128_32x2_tn_align8.cu", + "cutlass_tensorop_h1688gemm_split_k_parallel_128x128_32x2_tn_align8.cu", + "cutlass_tensorop_h1688gemm_256x128_32x2_tn_align4.cu", + "cutlass_tensorop_h1688gemm_split_k_parallel_256x128_32x2_tn_align4.cu", + "cutlass_tensorop_h1688gemm_128x256_32x2_tn_align4.cu", + "cutlass_tensorop_h1688gemm_split_k_parallel_128x256_32x2_tn_align4.cu", + "cutlass_tensorop_h1688gemm_128x128_32x2_tn_align4.cu", + "cutlass_tensorop_h1688gemm_split_k_parallel_128x128_32x2_tn_align4.cu", + "cutlass_tensorop_h1688gemm_256x128_32x2_tn_align2.cu", + "cutlass_tensorop_h1688gemm_split_k_parallel_256x128_32x2_tn_align2.cu", + "cutlass_tensorop_h1688gemm_128x256_32x2_tn_align2.cu", + "cutlass_tensorop_h1688gemm_split_k_parallel_128x256_32x2_tn_align2.cu", + "cutlass_tensorop_h1688gemm_128x128_32x2_tn_align2.cu", + "cutlass_tensorop_h1688gemm_split_k_parallel_128x128_32x2_tn_align2.cu", + "cutlass_tensorop_h1688gemm_256x128_32x2_tt_align8.cu", + "cutlass_tensorop_h1688gemm_split_k_parallel_256x128_32x2_tt_align8.cu", + "cutlass_tensorop_h1688gemm_128x256_32x2_tt_align8.cu", + "cutlass_tensorop_h1688gemm_split_k_parallel_128x256_32x2_tt_align8.cu", + "cutlass_tensorop_h1688gemm_128x128_32x2_tt_align8.cu", + "cutlass_tensorop_h1688gemm_split_k_parallel_128x128_32x2_tt_align8.cu", + 
"cutlass_tensorop_h1688gemm_256x128_32x2_tt_align4.cu", + "cutlass_tensorop_h1688gemm_split_k_parallel_256x128_32x2_tt_align4.cu", + "cutlass_tensorop_h1688gemm_128x256_32x2_tt_align4.cu", + "cutlass_tensorop_h1688gemm_split_k_parallel_128x256_32x2_tt_align4.cu", + "cutlass_tensorop_h1688gemm_128x128_32x2_tt_align4.cu", + "cutlass_tensorop_h1688gemm_split_k_parallel_128x128_32x2_tt_align4.cu", + "cutlass_tensorop_h1688gemm_256x128_32x2_tt_align2.cu", + "cutlass_tensorop_h1688gemm_split_k_parallel_256x128_32x2_tt_align2.cu", + "cutlass_tensorop_h1688gemm_128x256_32x2_tt_align2.cu", + "cutlass_tensorop_h1688gemm_split_k_parallel_128x256_32x2_tt_align2.cu", + "cutlass_tensorop_h1688gemm_128x128_32x2_tt_align2.cu", + "cutlass_tensorop_h1688gemm_split_k_parallel_128x128_32x2_tt_align2.cu", + "all_gemm_tensorop1688_operations.cu", + "cutlass_tensorop_f16_s884gemm_f16_256x128_32x2_nn_align8.cu", + "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_256x128_32x2_nn_align8.cu", + "cutlass_tensorop_f16_s884gemm_f16_128x256_32x2_nn_align8.cu", + "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x256_32x2_nn_align8.cu", + "cutlass_tensorop_f16_s884gemm_f16_128x128_32x2_nn_align8.cu", + "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x128_32x2_nn_align8.cu", + "cutlass_tensorop_f16_s884gemm_f16_256x128_32x2_nn_align4.cu", + "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_256x128_32x2_nn_align4.cu", + "cutlass_tensorop_f16_s884gemm_f16_128x256_32x2_nn_align4.cu", + "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x256_32x2_nn_align4.cu", + "cutlass_tensorop_f16_s884gemm_f16_128x128_32x2_nn_align4.cu", + "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x128_32x2_nn_align4.cu", + "cutlass_tensorop_f16_s884gemm_f16_256x128_32x2_nn_align2.cu", + "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_256x128_32x2_nn_align2.cu", + "cutlass_tensorop_f16_s884gemm_f16_128x256_32x2_nn_align2.cu", + "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x256_32x2_nn_align2.cu", + "cutlass_tensorop_f16_s884gemm_f16_128x128_32x2_nn_align2.cu", + "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x128_32x2_nn_align2.cu", + "cutlass_tensorop_f16_s884gemm_f16_256x128_32x2_nt_align8.cu", + "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_256x128_32x2_nt_align8.cu", + "cutlass_tensorop_f16_s884gemm_f16_128x256_32x2_nt_align8.cu", + "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x256_32x2_nt_align8.cu", + "cutlass_tensorop_f16_s884gemm_f16_128x128_32x2_nt_align8.cu", + "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x128_32x2_nt_align8.cu", + "cutlass_tensorop_f16_s884gemm_f16_256x128_32x2_nt_align4.cu", + "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_256x128_32x2_nt_align4.cu", + "cutlass_tensorop_f16_s884gemm_f16_128x256_32x2_nt_align4.cu", + "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x256_32x2_nt_align4.cu", + "cutlass_tensorop_f16_s884gemm_f16_128x128_32x2_nt_align4.cu", + "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x128_32x2_nt_align4.cu", + "cutlass_tensorop_f16_s884gemm_f16_256x128_32x2_nt_align2.cu", + "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_256x128_32x2_nt_align2.cu", + "cutlass_tensorop_f16_s884gemm_f16_128x256_32x2_nt_align2.cu", + "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x256_32x2_nt_align2.cu", + "cutlass_tensorop_f16_s884gemm_f16_128x128_32x2_nt_align2.cu", + "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x128_32x2_nt_align2.cu", + "cutlass_tensorop_f16_s884gemm_f16_256x128_32x2_tn_align8.cu", + 
"cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_256x128_32x2_tn_align8.cu", + "cutlass_tensorop_f16_s884gemm_f16_128x256_32x2_tn_align8.cu", + "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x256_32x2_tn_align8.cu", + "cutlass_tensorop_f16_s884gemm_f16_128x128_32x2_tn_align8.cu", + "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x128_32x2_tn_align8.cu", + "cutlass_tensorop_f16_s884gemm_f16_256x128_32x2_tn_align4.cu", + "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_256x128_32x2_tn_align4.cu", + "cutlass_tensorop_f16_s884gemm_f16_128x256_32x2_tn_align4.cu", + "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x256_32x2_tn_align4.cu", + "cutlass_tensorop_f16_s884gemm_f16_128x128_32x2_tn_align4.cu", + "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x128_32x2_tn_align4.cu", + "cutlass_tensorop_f16_s884gemm_f16_256x128_32x2_tn_align2.cu", + "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_256x128_32x2_tn_align2.cu", + "cutlass_tensorop_f16_s884gemm_f16_128x256_32x2_tn_align2.cu", + "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x256_32x2_tn_align2.cu", + "cutlass_tensorop_f16_s884gemm_f16_128x128_32x2_tn_align2.cu", + "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x128_32x2_tn_align2.cu", + "cutlass_tensorop_f16_s884gemm_f16_256x128_32x2_tt_align8.cu", + "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_256x128_32x2_tt_align8.cu", + "cutlass_tensorop_f16_s884gemm_f16_128x256_32x2_tt_align8.cu", + "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x256_32x2_tt_align8.cu", + "cutlass_tensorop_f16_s884gemm_f16_128x128_32x2_tt_align8.cu", + "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x128_32x2_tt_align8.cu", + "cutlass_tensorop_f16_s884gemm_f16_256x128_32x2_tt_align4.cu", + "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_256x128_32x2_tt_align4.cu", + "cutlass_tensorop_f16_s884gemm_f16_128x256_32x2_tt_align4.cu", + "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x256_32x2_tt_align4.cu", + "cutlass_tensorop_f16_s884gemm_f16_128x128_32x2_tt_align4.cu", + "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x128_32x2_tt_align4.cu", + "cutlass_tensorop_f16_s884gemm_f16_256x128_32x2_tt_align2.cu", + "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_256x128_32x2_tt_align2.cu", + "cutlass_tensorop_f16_s884gemm_f16_128x256_32x2_tt_align2.cu", + "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x256_32x2_tt_align2.cu", + "cutlass_tensorop_f16_s884gemm_f16_128x128_32x2_tt_align2.cu", + "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x128_32x2_tt_align2.cu", + "cutlass_tensorop_h884gemm_256x128_32x2_nn_align8.cu", + "cutlass_tensorop_h884gemm_split_k_parallel_256x128_32x2_nn_align8.cu", + "cutlass_tensorop_h884gemm_128x256_32x2_nn_align8.cu", + "cutlass_tensorop_h884gemm_split_k_parallel_128x256_32x2_nn_align8.cu", + "cutlass_tensorop_h884gemm_128x128_32x2_nn_align8.cu", + "cutlass_tensorop_h884gemm_split_k_parallel_128x128_32x2_nn_align8.cu", + "cutlass_tensorop_h884gemm_256x128_32x2_nn_align4.cu", + "cutlass_tensorop_h884gemm_split_k_parallel_256x128_32x2_nn_align4.cu", + "cutlass_tensorop_h884gemm_128x256_32x2_nn_align4.cu", + "cutlass_tensorop_h884gemm_split_k_parallel_128x256_32x2_nn_align4.cu", + "cutlass_tensorop_h884gemm_128x128_32x2_nn_align4.cu", + "cutlass_tensorop_h884gemm_split_k_parallel_128x128_32x2_nn_align4.cu", + "cutlass_tensorop_h884gemm_256x128_32x2_nn_align2.cu", + "cutlass_tensorop_h884gemm_split_k_parallel_256x128_32x2_nn_align2.cu", + "cutlass_tensorop_h884gemm_128x256_32x2_nn_align2.cu", + 
"cutlass_tensorop_h884gemm_split_k_parallel_128x256_32x2_nn_align2.cu", + "cutlass_tensorop_h884gemm_128x128_32x2_nn_align2.cu", + "cutlass_tensorop_h884gemm_split_k_parallel_128x128_32x2_nn_align2.cu", + "cutlass_tensorop_h884gemm_256x128_32x2_nt_align8.cu", + "cutlass_tensorop_h884gemm_split_k_parallel_256x128_32x2_nt_align8.cu", + "cutlass_tensorop_h884gemm_128x256_32x2_nt_align8.cu", + "cutlass_tensorop_h884gemm_split_k_parallel_128x256_32x2_nt_align8.cu", + "cutlass_tensorop_h884gemm_128x128_32x2_nt_align8.cu", + "cutlass_tensorop_h884gemm_split_k_parallel_128x128_32x2_nt_align8.cu", + "cutlass_tensorop_h884gemm_256x128_32x2_nt_align4.cu", + "cutlass_tensorop_h884gemm_split_k_parallel_256x128_32x2_nt_align4.cu", + "cutlass_tensorop_h884gemm_128x256_32x2_nt_align4.cu", + "cutlass_tensorop_h884gemm_split_k_parallel_128x256_32x2_nt_align4.cu", + "cutlass_tensorop_h884gemm_128x128_32x2_nt_align4.cu", + "cutlass_tensorop_h884gemm_split_k_parallel_128x128_32x2_nt_align4.cu", + "cutlass_tensorop_h884gemm_256x128_32x2_nt_align2.cu", + "cutlass_tensorop_h884gemm_split_k_parallel_256x128_32x2_nt_align2.cu", + "cutlass_tensorop_h884gemm_128x256_32x2_nt_align2.cu", + "cutlass_tensorop_h884gemm_split_k_parallel_128x256_32x2_nt_align2.cu", + "cutlass_tensorop_h884gemm_128x128_32x2_nt_align2.cu", + "cutlass_tensorop_h884gemm_split_k_parallel_128x128_32x2_nt_align2.cu", + "cutlass_tensorop_h884gemm_256x128_32x2_tn_align8.cu", + "cutlass_tensorop_h884gemm_split_k_parallel_256x128_32x2_tn_align8.cu", + "cutlass_tensorop_h884gemm_128x256_32x2_tn_align8.cu", + "cutlass_tensorop_h884gemm_split_k_parallel_128x256_32x2_tn_align8.cu", + "cutlass_tensorop_h884gemm_128x128_32x2_tn_align8.cu", + "cutlass_tensorop_h884gemm_split_k_parallel_128x128_32x2_tn_align8.cu", + "cutlass_tensorop_h884gemm_256x128_32x2_tn_align4.cu", + "cutlass_tensorop_h884gemm_split_k_parallel_256x128_32x2_tn_align4.cu", + "cutlass_tensorop_h884gemm_128x256_32x2_tn_align4.cu", + "cutlass_tensorop_h884gemm_split_k_parallel_128x256_32x2_tn_align4.cu", + "cutlass_tensorop_h884gemm_128x128_32x2_tn_align4.cu", + "cutlass_tensorop_h884gemm_split_k_parallel_128x128_32x2_tn_align4.cu", + "cutlass_tensorop_h884gemm_256x128_32x2_tn_align2.cu", + "cutlass_tensorop_h884gemm_split_k_parallel_256x128_32x2_tn_align2.cu", + "cutlass_tensorop_h884gemm_128x256_32x2_tn_align2.cu", + "cutlass_tensorop_h884gemm_split_k_parallel_128x256_32x2_tn_align2.cu", + "cutlass_tensorop_h884gemm_128x128_32x2_tn_align2.cu", + "cutlass_tensorop_h884gemm_split_k_parallel_128x128_32x2_tn_align2.cu", + "cutlass_tensorop_h884gemm_256x128_32x2_tt_align8.cu", + "cutlass_tensorop_h884gemm_split_k_parallel_256x128_32x2_tt_align8.cu", + "cutlass_tensorop_h884gemm_128x256_32x2_tt_align8.cu", + "cutlass_tensorop_h884gemm_split_k_parallel_128x256_32x2_tt_align8.cu", + "cutlass_tensorop_h884gemm_128x128_32x2_tt_align8.cu", + "cutlass_tensorop_h884gemm_split_k_parallel_128x128_32x2_tt_align8.cu", + "cutlass_tensorop_h884gemm_256x128_32x2_tt_align4.cu", + "cutlass_tensorop_h884gemm_split_k_parallel_256x128_32x2_tt_align4.cu", + "cutlass_tensorop_h884gemm_128x256_32x2_tt_align4.cu", + "cutlass_tensorop_h884gemm_split_k_parallel_128x256_32x2_tt_align4.cu", + "cutlass_tensorop_h884gemm_128x128_32x2_tt_align4.cu", + "cutlass_tensorop_h884gemm_split_k_parallel_128x128_32x2_tt_align4.cu", + "cutlass_tensorop_h884gemm_256x128_32x2_tt_align2.cu", + "cutlass_tensorop_h884gemm_split_k_parallel_256x128_32x2_tt_align2.cu", + "cutlass_tensorop_h884gemm_128x256_32x2_tt_align2.cu", + 
"cutlass_tensorop_h884gemm_split_k_parallel_128x256_32x2_tt_align2.cu", + "cutlass_tensorop_h884gemm_128x128_32x2_tt_align2.cu", + "cutlass_tensorop_h884gemm_split_k_parallel_128x128_32x2_tt_align2.cu", + "all_gemm_tensorop884_operations.cu", "cutlass_simt_sgemv_batched_strided_1x128_32_tt_align4x4.cu", "cutlass_simt_sgemv_batched_strided_1x128_16_tt_align4x2.cu", "cutlass_simt_sgemv_batched_strided_1x128_8_tt_align4x1.cu", @@ -646,4 +936,4 @@ cutlass_gen_list = [ "cutlass_tensorop_u4_i8832fprop_roc_relu_u4_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", "cutlass_tensorop_u4_i8832fprop_1x1_roc_relu_u4_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", "all_conv2d_tensorop8832_operations.cu", -] \ No newline at end of file +] diff --git a/dnn/src/CMakeLists.txt b/dnn/src/CMakeLists.txt index 43a900ac..a2e9b9b2 100644 --- a/dnn/src/CMakeLists.txt +++ b/dnn/src/CMakeLists.txt @@ -151,6 +151,8 @@ if(MGE_WITH_CUDA) set(${gen_files} "${${gen_files}}" PARENT_SCOPE) endfunction() gen_cutlass_kimpl(gemm simt CUTLASS_SOURCES) + gen_cutlass_kimpl(gemm tensorop884 CUTLASS_SOURCES) + gen_cutlass_kimpl(gemm tensorop1688 CUTLASS_SOURCES) gen_cutlass_kimpl(gemv simt CUTLASS_SOURCES) gen_cutlass_kimpl(deconv simt CUTLASS_SOURCES) gen_cutlass_kimpl(conv2d simt CUTLASS_SOURCES) diff --git a/dnn/src/cuda/cutlass/initialize_all.cu b/dnn/src/cuda/cutlass/initialize_all.cu index 1e67822d..e836dd76 100644 --- a/dnn/src/cuda/cutlass/initialize_all.cu +++ b/dnn/src/cuda/cutlass/initialize_all.cu @@ -49,6 +49,8 @@ namespace library { (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2) void initialize_all_gemm_simt_operations(Manifest& manifest); +void initialize_all_gemm_tensorop884_operations(Manifest& manifest); +void initialize_all_gemm_tensorop1688_operations(Manifest& manifest); void initialize_all_conv2d_simt_operations(Manifest& manifest); void initialize_all_conv2d_tensorop8816_operations(Manifest& manifest); void initialize_all_conv2d_tensorop8832_operations(Manifest& manifest); @@ -56,6 +58,8 @@ void initialize_all_deconv_simt_operations(Manifest& manifest); void initialize_all(Manifest& manifest) { initialize_all_gemm_simt_operations(manifest); + initialize_all_gemm_tensorop884_operations(manifest); + initialize_all_gemm_tensorop1688_operations(manifest); initialize_all_conv2d_simt_operations(manifest); initialize_all_conv2d_tensorop8816_operations(manifest); initialize_all_conv2d_tensorop8832_operations(manifest); diff --git a/dnn/src/cuda/cutlass/operation_table.cpp b/dnn/src/cuda/cutlass/operation_table.cpp index a4858fe8..16655918 100644 --- a/dnn/src/cuda/cutlass/operation_table.cpp +++ b/dnn/src/cuda/cutlass/operation_table.cpp @@ -55,6 +55,8 @@ GemmKey get_gemm_key_from_desc(const GemmDescription& desc) { key.layout_B = desc.B.layout; key.element_C = desc.C.element; key.layout_C = desc.C.layout; + key.element_accumulator = + desc.tile_description.math_instruction.element_accumulator; key.threadblock_shape_m = desc.tile_description.threadblock_shape.m(); key.threadblock_shape_n = desc.tile_description.threadblock_shape.n(); @@ -75,6 +77,8 @@ GemmKey get_gemm_key_from_desc(const GemmDescription& desc) { desc.tile_description.math_instruction.instruction_shape.k(); key.stages = desc.stages; + key.alignment_A = desc.A.alignment; + key.alignment_B = desc.B.alignment; key.split_k_mode = desc.split_k_mode; return key; diff --git a/dnn/src/cuda/cutlass/operation_table.h b/dnn/src/cuda/cutlass/operation_table.h index 6190ff56..420e1e55 100644 --- a/dnn/src/cuda/cutlass/operation_table.h +++ 
b/dnn/src/cuda/cutlass/operation_table.h @@ -77,6 +77,7 @@ struct GemmKey { LayoutTypeID layout_B; NumericTypeID element_C; LayoutTypeID layout_C; + NumericTypeID element_accumulator; int threadblock_shape_m; int threadblock_shape_n; @@ -91,12 +92,15 @@ struct GemmKey { int instruction_shape_k; int stages; + int alignment_A; + int alignment_B; SplitKMode split_k_mode; inline bool operator==(GemmKey const& rhs) const { return (element_A == rhs.element_A) && (layout_A == rhs.layout_A) && (element_B == rhs.element_B) && (layout_B == rhs.layout_B) && (element_C == rhs.element_C) && (layout_C == rhs.layout_C) && + (element_accumulator == rhs.element_accumulator) && (threadblock_shape_m == rhs.threadblock_shape_m) && (threadblock_shape_n == rhs.threadblock_shape_n) && (threadblock_shape_k == rhs.threadblock_shape_k) && @@ -106,7 +110,9 @@ struct GemmKey { (instruction_shape_m == rhs.instruction_shape_m) && (instruction_shape_n == rhs.instruction_shape_n) && (instruction_shape_k == rhs.instruction_shape_k) && - (stages == rhs.stages) && (split_k_mode == rhs.split_k_mode); + (stages == rhs.stages) && (alignment_A == rhs.alignment_A) && + (alignment_B == rhs.alignment_B) && + (split_k_mode == rhs.split_k_mode); } inline bool operator!=(GemmKey const& rhs) const { return !(*this == rhs); } @@ -130,10 +136,13 @@ struct GemmKey { "\n layout_B: " + to_string(layout_B) + "\n element_C: " + to_string(element_C) + "\n layout_C: " + to_string(layout_C) + + "\n element_accumulator: " + to_string(element_accumulator) + "\n threadblock_shape: " + threadblock_shape_str + "\n warp_shape: " + warp_shape_str + "\n instruction_shape: " + instruction_shape_str + "\n stages: " + std::to_string(stages) + + "\n alignment_A: " + std::to_string(alignment_A) + + "\n alignment_B: " + std::to_string(alignment_B) + "\n split_k_mode: " + to_string(split_k_mode) + "\n}"; } }; @@ -147,6 +156,8 @@ struct GemmKeyHasher { .update(&key.layout_B, sizeof(key.layout_B)) .update(&key.element_C, sizeof(key.element_C)) .update(&key.layout_C, sizeof(key.layout_C)) + .update(&key.element_accumulator, + sizeof(key.element_accumulator)) .update(&key.threadblock_shape_m, sizeof(key.threadblock_shape_m)) .update(&key.threadblock_shape_n, @@ -157,6 +168,8 @@ struct GemmKeyHasher { .update(&key.warp_shape_n, sizeof(key.warp_shape_n)) .update(&key.warp_shape_k, sizeof(key.warp_shape_k)) .update(&key.stages, sizeof(key.stages)) + .update(&key.alignment_A, sizeof(key.alignment_A)) + .update(&key.alignment_B, sizeof(key.alignment_B)) .update(&key.split_k_mode, sizeof(key.split_k_mode)) .digest(); } diff --git a/dnn/src/cuda/matrix_mul/algos.cpp b/dnn/src/cuda/matrix_mul/algos.cpp index c9c01692..e3bb328d 100644 --- a/dnn/src/cuda/matrix_mul/algos.cpp +++ b/dnn/src/cuda/matrix_mul/algos.cpp @@ -43,6 +43,12 @@ MatrixMulForwardImpl::AlgoPack::AlgoPack() { for (auto&& algo : simt_float32_gemv_batched_strided) { all_algos.push_back(&algo); } + for (auto&& algo : tensorop_float16) { + all_algos.push_back(&algo); + } + for (auto&& algo : tensorop_float16_split_k) { + all_algos.push_back(&algo); + } #endif all_algos.push_back(&naive); @@ -53,7 +59,7 @@ MatrixMulForwardImpl::AlgoPack::AlgoPack() { #if CUDA_VERSION >= 9020 void MatrixMulForwardImpl::AlgoPack::fill_cutlass_algos() { - using AlgoParam = AlgoFloat32SIMT::AlgoParam; + using AlgoParam = AlgoCutlassMatrixMulBase::AlgoParam; simt_float32.emplace_back(AlgoParam{64, 256, 8, 32, 64, 8}); simt_float32.emplace_back(AlgoParam{256, 64, 8, 64, 32, 8}); simt_float32.emplace_back(AlgoParam{32, 256, 8, 16, 
64, 8}); @@ -91,6 +97,19 @@ void MatrixMulForwardImpl::AlgoPack::fill_cutlass_algos() { simt_float32_gemv_batched_strided.emplace_back(128); simt_float32_gemv_batched_strided.emplace_back(64); simt_float32_gemv_batched_strided.emplace_back(32); +#define FOREACH_CUTLASS_MATMUL_F16_SHAPES(cb) \ + cb(256, 128, 32, 64, 64, 32, 8, 8, 4); \ + cb(128, 256, 32, 64, 64, 32, 8, 8, 4); \ + cb(128, 128, 32, 64, 64, 32, 8, 8, 4); \ + cb(256, 128, 32, 64, 64, 32, 16, 8, 8); \ + cb(128, 256, 32, 64, 64, 32, 16, 8, 8); \ + cb(128, 128, 32, 64, 64, 32, 16, 8, 8); +#define cb(...) \ + tensorop_float16.emplace_back(AlgoParam{__VA_ARGS__}); \ + tensorop_float16_split_k.emplace_back(AlgoParam{__VA_ARGS__}); + FOREACH_CUTLASS_MATMUL_F16_SHAPES(cb) +#undef cb +#undef FOREACH_CUTLASS_MATMUL_F16_SHAPES } #endif diff --git a/dnn/src/cuda/matrix_mul/algos.h b/dnn/src/cuda/matrix_mul/algos.h index 27e3fb6f..34c7cbf3 100644 --- a/dnn/src/cuda/matrix_mul/algos.h +++ b/dnn/src/cuda/matrix_mul/algos.h @@ -41,11 +41,13 @@ public: CUDA_WMMA_UINT4X4X32, CUDA_CUBLASLT, CUDA_NAIVE, - CUDA_BFLOAT16, + CUDA_BFLOAT16, #if CUDA_VERSION >= 9020 - CUDA_FLOAT32_SIMT, - CUDA_FLOAT32_SIMT_SPLIT_K, - CUDA_FLOAT32_SIMT_GEMV_BATCHED_STRIDED, + CUDA_FLOAT32_SIMT, + CUDA_FLOAT32_SIMT_SPLIT_K, + CUDA_FLOAT32_SIMT_GEMV_BATCHED_STRIDED, + CUDA_FLOAT16_TENSOR_OP, + CUDA_FLOAT16_TENSOR_OP_SPLIT_K, #endif }; using Mapper = std::unordered_map; @@ -188,65 +190,83 @@ private: #endif #if CUDA_VERSION >= 9020 -class MatrixMulForwardImpl::AlgoFloat32SIMT final : public AlgoBase { +class MatrixMulForwardImpl::AlgoCutlassMatrixMulBase : public AlgoBase { public: struct AlgoParam { int threadblock_m, threadblock_n, threadblock_k; int warp_m, warp_n, warp_k; - std::string to_string() { - return ssprintf("%dX%dX%d_%dX%dX%d", threadblock_m, threadblock_n, - threadblock_k, warp_m, warp_n, warp_k); - } + int instruction_m, instruction_n, instruction_k; + AlgoParam(int threadblock_m_, int threadblock_n_, int threadblock_k_, + int warp_m_, int warp_n_, int warp_k_, int instruction_m_ = 1, + int instruction_n_ = 1, int instruction_k_ = 1) + : threadblock_m{threadblock_m_}, + threadblock_n{threadblock_n_}, + threadblock_k{threadblock_k_}, + warp_m{warp_m_}, + warp_n{warp_n_}, + warp_k{warp_k_}, + instruction_m{instruction_m_}, + instruction_n{instruction_n_}, + instruction_k{instruction_k_} {} + std::string to_string() const; }; + AlgoCutlassMatrixMulBase(AlgoParam algo_param) : m_algo_param{algo_param} {} + void exec(const ExecArgs& args) const override; + std::string param() const override { + std::string ret; + serialize_write_pod(m_algo_param, ret); + return ret; + } + +protected: + virtual int min_alignment_requirement() const = 0; + virtual void do_exec(const ExecArgs& args) const = 0; + std::pair construct_aligned_layouts( + const SizeArgs& args) const; + int max_alignment(const SizeArgs& args) const; + AlgoParam m_algo_param; +}; + +class MatrixMulForwardImpl::AlgoFloat32SIMT final + : public AlgoCutlassMatrixMulBase { +public: AlgoFloat32SIMT(AlgoParam algo_param) - : m_algo_param{algo_param}, + : AlgoCutlassMatrixMulBase{algo_param}, m_name{ssprintf("CUTLASS_FLOAT32_SIMT_%s", m_algo_param.to_string().c_str())} {} bool is_available(const SizeArgs& args) const override; size_t get_workspace_in_bytes(const SizeArgs& args) const override; const char* name() const override { return m_name.c_str(); } - void exec(const ExecArgs& args) const override; AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } 
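The FOREACH_CUTLASS_MATMUL_F16_SHAPES macro above registers every fp16 tile shape twice, once for the plain tensor-op algorithm and once for its split-k variant, and the trailing (instruction_m, instruction_n, instruction_k) triple picks the mma shape: 8x8x4 for the sm70-class h884 kernels, 16x8x8 for the sm75-class h1688 kernels. A minimal self-contained sketch of that expansion follows; the simplified AlgoParam and the two tile shapes are taken from this patch, while the printf merely mirrors the CUTLASS_FLOAT16_TENSOR_OP_h%d%d%d_%s naming and everything else is illustrative, not MegDNN code.

#include <cstdio>
#include <vector>

// Simplified stand-in for AlgoCutlassMatrixMulBase::AlgoParam (same nine fields).
struct AlgoParam {
    int threadblock_m, threadblock_n, threadblock_k;
    int warp_m, warp_n, warp_k;
    int instruction_m, instruction_n, instruction_k;
};

int main() {
    std::vector<AlgoParam> tensorop_float16, tensorop_float16_split_k;
    // Each cb(...) line registers one tile shape for both the plain tensor-op
    // algorithm and its split-k variant.
#define FOREACH_CUTLASS_MATMUL_F16_SHAPES(cb) \
    cb(256, 128, 32, 64, 64, 32, 8, 8, 4);    \
    cb(128, 128, 32, 64, 64, 32, 16, 8, 8);
#define cb(...)                                              \
    tensorop_float16.push_back(AlgoParam{__VA_ARGS__});      \
    tensorop_float16_split_k.push_back(AlgoParam{__VA_ARGS__})
    FOREACH_CUTLASS_MATMUL_F16_SHAPES(cb)
#undef cb
#undef FOREACH_CUTLASS_MATMUL_F16_SHAPES
    // Print the names the registered algorithms would carry,
    // e.g. CUTLASS_FLOAT16_TENSOR_OP_h884_256X128X32_64X64X32.
    for (const AlgoParam& p : tensorop_float16) {
        std::printf("CUTLASS_FLOAT16_TENSOR_OP_h%d%d%d_%dX%dX%d_%dX%dX%d\n",
                    p.instruction_m, p.instruction_n, p.instruction_k,
                    p.threadblock_m, p.threadblock_n, p.threadblock_k,
                    p.warp_m, p.warp_n, p.warp_k);
    }
    return 0;
}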
MEGDNN_DECL_ALGO_TYPE(CUDA_FLOAT32_SIMT) - std::string param() const override { - std::string ret; - serialize_write_pod(m_algo_param, ret); - return ret; - } - private: - AlgoParam m_algo_param; + void do_exec(const ExecArgs& args) const override; + int min_alignment_requirement() const override { return 1; } std::string m_name; }; -class MatrixMulForwardImpl::AlgoFloat32SIMTSplitK final : public AlgoBase { +class MatrixMulForwardImpl::AlgoFloat32SIMTSplitK final + : public AlgoCutlassMatrixMulBase { public: - using AlgoParam = MatrixMulForwardImpl::AlgoFloat32SIMT::AlgoParam; AlgoFloat32SIMTSplitK(AlgoParam algo_param) - : m_algo_param{algo_param}, + : AlgoCutlassMatrixMulBase{algo_param}, m_name{ssprintf("CUTLASS_FLOAT32_SIMT_SPLIT_K_%s", m_algo_param.to_string().c_str())} {} bool is_available(const SizeArgs& args) const override; size_t get_workspace_in_bytes(const SizeArgs& args) const override; const char* name() const override { return m_name.c_str(); } - void exec(const ExecArgs& args) const override; AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE; } MEGDNN_DECL_ALGO_TYPE(CUDA_FLOAT32_SIMT_SPLIT_K) - std::string param() const override { - std::string ret; - serialize_write_pod(m_algo_param, ret); - return ret; - } - private: - AlgoParam m_algo_param; + void do_exec(const ExecArgs& args) const override; + int min_alignment_requirement() const override { return 1; } std::string m_name; }; @@ -276,6 +296,56 @@ private: int m_threadblock_n; std::string m_name; }; + +class MatrixMulForwardImpl::AlgoFloat16TensorOp final + : public AlgoCutlassMatrixMulBase { +public: + AlgoFloat16TensorOp(AlgoParam algo_param) + : AlgoCutlassMatrixMulBase{algo_param}, + m_name{ssprintf("CUTLASS_FLOAT16_TENSOR_OP_h%d%d%d_%s", + m_algo_param.instruction_m, + m_algo_param.instruction_n, + m_algo_param.instruction_k, + m_algo_param.to_string().c_str())} {} + bool is_available(const SizeArgs& args) const override; + size_t get_workspace_in_bytes(const SizeArgs& args) const override; + const char* name() const override { return m_name.c_str(); } + AlgoAttribute attribute() const override { + return AlgoAttribute::REPRODUCIBLE; + } + MEGDNN_DECL_ALGO_TYPE(CUDA_FLOAT16_TENSOR_OP) + +private: + void do_exec(const ExecArgs& args) const override; + int min_alignment_requirement() const override { return 2; } + std::string m_name; +}; + +class MatrixMulForwardImpl::AlgoFloat16TensorOpSplitK final + : public AlgoCutlassMatrixMulBase { +public: + AlgoFloat16TensorOpSplitK(AlgoParam algo_param) + : AlgoCutlassMatrixMulBase{algo_param}, + m_name{ssprintf("CUTLASS_FLOAT16_TENSOR_OP_SPLIT_K_h%d%d%d_%s", + m_algo_param.instruction_m, + m_algo_param.instruction_n, + m_algo_param.instruction_k, + m_algo_param.to_string().c_str())} {} + bool is_available(const SizeArgs& args) const override; + size_t get_workspace_in_bytes(const SizeArgs& args) const override; + const char* name() const override { return m_name.c_str(); } + AlgoAttribute attribute() const override { + return AlgoAttribute::REPRODUCIBLE | + AlgoAttribute::USABLE_DEPEND_ON_SHAPE; + } + MEGDNN_DECL_ALGO_TYPE(CUDA_FLOAT16_TENSOR_OP_SPLIT_K) + +private: + void do_exec(const ExecArgs& args) const override; + int min_alignment_requirement() const override { return 2; } + std::string m_name; +}; + #endif class MatrixMulForwardImpl::AlgoPack : NonCopyableObj { @@ -300,6 +370,8 @@ public: std::vector simt_float32_split_k; std::vector simt_float32_gemv_batched_strided; + std::vector tensorop_float16; 
+ std::vector tensorop_float16_split_k; #endif std::vector all_algos; diff --git a/dnn/src/cuda/matrix_mul/cutlass_float16_tensorop.cpp b/dnn/src/cuda/matrix_mul/cutlass_float16_tensorop.cpp new file mode 100644 index 00000000..ea2c05e9 --- /dev/null +++ b/dnn/src/cuda/matrix_mul/cutlass_float16_tensorop.cpp @@ -0,0 +1,154 @@ +/** + * \file dnn/src/cuda/matrix_mul/cutlass_float16_tensorop.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#include "src/cuda/cutlass/singleton.h" +#include "src/cuda/handle.h" +#include "src/cuda/matrix_mul/algos.h" +#include "src/cuda/utils.h" + +#if CUDA_VERSION >= 9020 +using namespace megdnn; +using namespace cuda; + +bool MatrixMulForwardImpl::AlgoFloat16TensorOp::is_available( + const SizeArgs& args) const { + bool available = + args.opr->param().format == param::MatrixMul::Format::DEFAULT && + args.layout_b.dtype == dtype::Float16() && + args.layout_c.dtype == dtype::Float16(); + int n = args.layout_c.shape[1]; + auto&& device_prop = cuda::current_device_prop(); + int y_grid_limit = device_prop.maxGridSize[1]; + // limit y grid + available &= ((n + m_algo_param.threadblock_n - 1) / + m_algo_param.threadblock_n <= + y_grid_limit); + if (m_algo_param.instruction_m == 8 && m_algo_param.instruction_n == 8 && + m_algo_param.instruction_k == 4) { + available &= is_compute_capability_required(7, 0); + } else { + megdnn_assert(m_algo_param.instruction_m == 16 && + m_algo_param.instruction_n == 8 && + m_algo_param.instruction_k == 8); + available &= is_compute_capability_required(7, 5); + } + + return available; +} + +size_t MatrixMulForwardImpl::AlgoFloat16TensorOp::get_workspace_in_bytes( + const SizeArgs& args) const { + auto aligned = construct_aligned_layouts(args); + if (!aligned.first) + return 0_z; + const auto& layouts = aligned.second; + size_t ws_size = 0; + for (auto&& ly : layouts) { + ws_size += ly.span().dist_byte(); + } + return ws_size; +} + +void MatrixMulForwardImpl::AlgoFloat16TensorOp::do_exec( + const ExecArgs& args) const { + int64_t lda = args.tensor_a.layout.stride[0], + ldb = args.tensor_b.layout.stride[0], + ldc = args.tensor_c.layout.stride[0]; + int alignment = max_alignment(args); + int min_alignment = min_alignment_requirement(); + auto&& param = args.opr->param(); + int m = args.tensor_c.layout.shape[0], n = args.tensor_c.layout.shape[1], + k = args.tensor_a.layout.shape[param.transposeA ? 0 : 1]; + megdnn_assert(lda % alignment == 0 && ldb % alignment == 0 && + ldc % alignment == 0 && m % alignment == 0 && + n % alignment == 0 && k % alignment == 0 && + alignment >= min_alignment); + cutlass::gemm::GemmCoord problem_size{m, n, k}; + auto&& stream = cuda_stream(args.opr->handle()); + int* workspace = reinterpret_cast(args.workspace.raw_ptr); + // \note these constants (i.e. one and zero) of cutlass epilogue will be + // passed by pointers and interpreted as ElementCompute*, which will be used + // to initialize kernel parameters. So the arguments' type on the host side + // should be the same as the ElementCompute of kernel instance, otherwise + // undefined kernel bahaviors will occur caused by incorrect intepretation + // of these pointers. 
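This warning is why the code below keeps a dt_float16 one_f16/zero_f16 pair next to the float pair: under ComputeMode::DEFAULT the kernel's ElementCompute is fp16, so the epilogue reads only 16 bits at the alpha/beta pointers. A small standalone sketch, not part of the patch, of what it would see if the address of a 32-bit float were handed over instead:

#include <cstdint>
#include <cstdio>
#include <cstring>

int main() {
    float one = 1.f;  // bit pattern 0x3F800000
    std::uint16_t as_half_bits;
    // An fp16 epilogue dereferencing this pointer reads just the low two bytes.
    std::memcpy(&as_half_bits, &one, sizeof(as_half_bits));
    // On a little-endian host this prints 0x0000, i.e. alpha would be read as
    // fp16 zero and the A*B contribution would be silently zeroed out.
    std::printf("low 16 bits of float 1.0f: 0x%04X\n", as_half_bits);
    // fp16 1.0 actually has bit pattern 0x3C00, which is why separate
    // fp16-typed host constants are required for the fp16-accumulation path.
    return 0;
}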
+ float one = 1.f, zero = 0.f; + dt_float16 one_f16 = static_cast(one), + zero_f16 = static_cast(zero); + + using namespace cutlass::library; + + auto layoutA = param.transposeA ? LayoutTypeID::kColumnMajor + : LayoutTypeID::kRowMajor; + auto layoutB = param.transposeB ? LayoutTypeID::kColumnMajor + : LayoutTypeID::kRowMajor; + + void *host_one, *host_zero; + NumericTypeID element_accumulator; + if (param.compute_mode == param::MatrixMul::ComputeMode::DEFAULT) { + element_accumulator = NumericTypeID::kF16; + host_one = &one_f16; + host_zero = &zero_f16; + } else { + megdnn_assert(param.compute_mode == + param::MatrixMul::ComputeMode::FLOAT32); + element_accumulator = NumericTypeID::kF32; + host_one = &one; + host_zero = &zero; + } + + GemmKey key{NumericTypeID::kF16, + layoutA, + NumericTypeID::kF16, + layoutB, + NumericTypeID::kF16, + LayoutTypeID::kRowMajor, + element_accumulator, + m_algo_param.threadblock_m, + m_algo_param.threadblock_n, + m_algo_param.threadblock_k, + m_algo_param.warp_m, + m_algo_param.warp_n, + m_algo_param.warp_k, + m_algo_param.instruction_m, + m_algo_param.instruction_n, + m_algo_param.instruction_k, + 2, + alignment, + alignment, + SplitKMode::kNone}; + + const auto& table = Singleton::get().operation_table; + megdnn_assert(table.gemm_operations.count(key) > 0, + "key not found in cutlass operation table"); + const auto& ops = table.gemm_operations.at(key); + megdnn_assert(ops.size() == 1, "exactly one kernel expected, got %zu", + ops.size()); + + GemmArguments gemm_args{problem_size, + args.tensor_a.raw_ptr, + args.tensor_b.raw_ptr, + args.tensor_c.raw_ptr, + args.tensor_c.raw_ptr, + lda, + ldb, + ldc, + ldc, + 1, + host_one, + host_zero}; + + cutlass_check(ops[0]->run(&gemm_args, workspace, stream)); +} +#endif + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/matrix_mul/cutlass_float16_tensorop_split_k.cpp b/dnn/src/cuda/matrix_mul/cutlass_float16_tensorop_split_k.cpp new file mode 100644 index 00000000..18211251 --- /dev/null +++ b/dnn/src/cuda/matrix_mul/cutlass_float16_tensorop_split_k.cpp @@ -0,0 +1,165 @@ +/** + * \file dnn/src/cuda/matrix_mul/cutlass_float16_tensorop_split_k.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#include "src/cuda/cutlass/singleton.h" +#include "src/cuda/handle.h" +#include "src/cuda/matrix_mul/algos.h" +#include "src/cuda/utils.h" + +#if CUDA_VERSION >= 9020 +using namespace megdnn; +using namespace cuda; + +bool MatrixMulForwardImpl::AlgoFloat16TensorOpSplitK::is_available( + const SizeArgs& args) const { + auto&& param = args.opr->param(); + int n = args.layout_c.shape[1], + k = args.layout_a.shape[param.transposeA ? 
0 : 1]; + bool available = + args.opr->param().format == param::MatrixMul::Format::DEFAULT && + args.layout_a.dtype == dtype::Float16() && + args.layout_b.dtype == dtype::Float16() && + args.layout_c.dtype == dtype::Float16() && k > n; + auto&& device_prop = cuda::current_device_prop(); + int y_grid_limit = device_prop.maxGridSize[1]; + // limit y grid + available &= ((n + m_algo_param.threadblock_n - 1) / + m_algo_param.threadblock_n <= + y_grid_limit); + if (m_algo_param.instruction_m == 8 && m_algo_param.instruction_n == 8 && + m_algo_param.instruction_k == 4) { + available &= is_compute_capability_required(7, 0); + } else { + megdnn_assert(m_algo_param.instruction_m == 16 && + m_algo_param.instruction_n == 8 && + m_algo_param.instruction_k == 8); + available &= is_compute_capability_required(7, 5); + } + + return available; +} + +size_t MatrixMulForwardImpl::AlgoFloat16TensorOpSplitK::get_workspace_in_bytes( + const SizeArgs& args) const { + auto aligned = construct_aligned_layouts(args); + auto&& param = args.opr->param(); + int m = args.layout_c.shape[0], n = args.layout_c.shape[1], + k = args.layout_a.shape[param.transposeA ? 0 : 1]; + int split_k_slices = std::max(1, k / n); + if (!aligned.first) + return args.layout_c.dtype.size(m * n * split_k_slices); + const auto& layouts = aligned.second; + int align_m = layouts[2].shape[0], align_n = layouts[2].shape[1], + align_k = layouts[0].shape[1]; + split_k_slices = std::max(1, align_k / align_n); + size_t ws_size = + args.layout_c.dtype.size(align_m * align_n * split_k_slices); + for (auto&& ly : layouts) + ws_size += ly.span().dist_byte(); + return ws_size; +} + +void MatrixMulForwardImpl::AlgoFloat16TensorOpSplitK::do_exec( + const ExecArgs& args) const { + int64_t lda = args.tensor_a.layout.stride[0], + ldb = args.tensor_b.layout.stride[0], + ldc = args.tensor_c.layout.stride[0]; + int alignment = max_alignment(args); + int min_alignment = min_alignment_requirement(); + auto&& param = args.opr->param(); + int m = args.tensor_c.layout.shape[0], n = args.tensor_c.layout.shape[1], + k = args.tensor_a.layout.shape[param.transposeA ? 0 : 1]; + megdnn_assert(lda % alignment == 0 && ldb % alignment == 0 && + ldc % alignment == 0 && m % alignment == 0 && + n % alignment == 0 && k % alignment == 0 && + alignment >= min_alignment); + cutlass::gemm::GemmCoord problem_size{m, n, k}; + int split_k_slices = std::max(1, k / n); + auto&& stream = cuda_stream(args.opr->handle()); + int* workspace = reinterpret_cast(args.workspace.raw_ptr); + // \note these constants (i.e. one and zero) of cutlass epilogue will be + // passed by pointers and interpreted as ElementCompute*, which will be used + // to initialize kernel parameters. So the arguments' type on the host side + // should be the same as the ElementCompute of kernel instance, otherwise + // undefined kernel bahaviors will occur caused by incorrect intepretation + // of these pointers. + float one = 1.f, zero = 0.f; + dt_float16 one_f16 = static_cast(one), + zero_f16 = static_cast(zero); + + using namespace cutlass::library; + + auto layoutA = param.transposeA ? LayoutTypeID::kColumnMajor + : LayoutTypeID::kRowMajor; + auto layoutB = param.transposeB ? 
LayoutTypeID::kColumnMajor + : LayoutTypeID::kRowMajor; + + void *host_one, *host_zero; + NumericTypeID element_accumulator; + if (param.compute_mode == param::MatrixMul::ComputeMode::DEFAULT) { + element_accumulator = NumericTypeID::kF16; + host_one = &one_f16; + host_zero = &zero_f16; + } else { + megdnn_assert(param.compute_mode == + param::MatrixMul::ComputeMode::FLOAT32); + element_accumulator = NumericTypeID::kF32; + host_one = &one; + host_zero = &zero; + } + + GemmKey key{NumericTypeID::kF16, + layoutA, + NumericTypeID::kF16, + layoutB, + NumericTypeID::kF16, + LayoutTypeID::kRowMajor, + element_accumulator, + m_algo_param.threadblock_m, + m_algo_param.threadblock_n, + m_algo_param.threadblock_k, + m_algo_param.warp_m, + m_algo_param.warp_n, + m_algo_param.warp_k, + m_algo_param.instruction_m, + m_algo_param.instruction_n, + m_algo_param.instruction_k, + 2, + alignment, + alignment, + SplitKMode::kParallel}; + + const auto& table = Singleton::get().operation_table; + megdnn_assert(table.gemm_operations.count(key) > 0, + "key not found in cutlass operation table"); + const auto& ops = table.gemm_operations.at(key); + megdnn_assert(ops.size() == 1, "exactly one kernel expected, got %zu", + ops.size()); + + GemmArguments gemm_args{problem_size, + args.tensor_a.raw_ptr, + args.tensor_b.raw_ptr, + args.tensor_c.raw_ptr, + args.tensor_c.raw_ptr, + lda, + ldb, + ldc, + ldc, + split_k_slices, + host_one, + host_zero}; + + cutlass_check(ops[0]->run(&gemm_args, workspace, stream)); +} +#endif + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/matrix_mul/cutlass_float32_simt.cpp b/dnn/src/cuda/matrix_mul/cutlass_float32_simt.cpp index 05fa960b..63d9faef 100644 --- a/dnn/src/cuda/matrix_mul/cutlass_float32_simt.cpp +++ b/dnn/src/cuda/matrix_mul/cutlass_float32_simt.cpp @@ -42,7 +42,8 @@ size_t MatrixMulForwardImpl::AlgoFloat32SIMT::get_workspace_in_bytes( return 0_z; } -void MatrixMulForwardImpl::AlgoFloat32SIMT::exec(const ExecArgs& args) const { +void MatrixMulForwardImpl::AlgoFloat32SIMT::do_exec( + const ExecArgs& args) const { int64_t lda = args.tensor_a.layout.stride[0], ldb = args.tensor_b.layout.stride[0], ldc = args.tensor_c.layout.stride[0]; @@ -65,12 +66,14 @@ void MatrixMulForwardImpl::AlgoFloat32SIMT::exec(const ExecArgs& args) const { auto layoutB = param.transposeB ? LayoutTypeID::kColumnMajor : LayoutTypeID::kRowMajor; + int alignment = min_alignment_requirement(); GemmKey key{NumericTypeID::kF32, layoutA, NumericTypeID::kF32, layoutB, NumericTypeID::kF32, LayoutTypeID::kRowMajor, + NumericTypeID::kF32, m_algo_param.threadblock_m, m_algo_param.threadblock_n, m_algo_param.threadblock_k, @@ -79,8 +82,10 @@ void MatrixMulForwardImpl::AlgoFloat32SIMT::exec(const ExecArgs& args) const { m_algo_param.warp_k, 1, 1, - 1, - 2, + 1, + 2, + alignment, + alignment, SplitKMode::kNone}; const Operation* op = Singleton::get().operation_table.find_op(key); diff --git a/dnn/src/cuda/matrix_mul/cutlass_float32_simt_split_k.cpp b/dnn/src/cuda/matrix_mul/cutlass_float32_simt_split_k.cpp index 2baa7f63..10ef7f42 100644 --- a/dnn/src/cuda/matrix_mul/cutlass_float32_simt_split_k.cpp +++ b/dnn/src/cuda/matrix_mul/cutlass_float32_simt_split_k.cpp @@ -22,7 +22,7 @@ using namespace cuda; bool MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::is_available( const SizeArgs& args) const { auto&& param = args.opr->param(); - int m = args.layout_c.shape[0], n = args.layout_c.shape[1], + int n = args.layout_c.shape[1], k = args.layout_a.shape[param.transposeA ? 
0 : 1]; bool available = args.opr->param().format == param::MatrixMul::Format::DEFAULT && @@ -32,8 +32,8 @@ bool MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::is_available( auto&& device_prop = cuda::current_device_prop(); int y_grid_limit = device_prop.maxGridSize[1]; // limit y grid - available &= ((m + m_algo_param.threadblock_m - 1) / - m_algo_param.threadblock_m <= + available &= ((n + m_algo_param.threadblock_n - 1) / + m_algo_param.threadblock_n <= y_grid_limit); return available; } @@ -47,7 +47,7 @@ size_t MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::get_workspace_in_bytes( return args.layout_c.dtype.size(m * n * split_k_slices); } -void MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::exec( +void MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::do_exec( const ExecArgs& args) const { int64_t lda = args.tensor_a.layout.stride[0], ldb = args.tensor_b.layout.stride[0], @@ -72,12 +72,14 @@ void MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::exec( auto layoutB = param.transposeB ? LayoutTypeID::kColumnMajor : LayoutTypeID::kRowMajor; + int alignment = min_alignment_requirement(); GemmKey key{NumericTypeID::kF32, layoutA, NumericTypeID::kF32, layoutB, NumericTypeID::kF32, LayoutTypeID::kRowMajor, + NumericTypeID::kF32, m_algo_param.threadblock_m, m_algo_param.threadblock_n, m_algo_param.threadblock_k, @@ -87,7 +89,9 @@ void MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::exec( 1, 1, 1, - 2, + 2, + alignment, + alignment, SplitKMode::kParallel}; Operation const* op = Singleton::get().operation_table.find_op(key); diff --git a/dnn/src/cuda/matrix_mul/cutlass_matrix_mul_base.cpp b/dnn/src/cuda/matrix_mul/cutlass_matrix_mul_base.cpp new file mode 100644 index 00000000..55a95d1f --- /dev/null +++ b/dnn/src/cuda/matrix_mul/cutlass_matrix_mul_base.cpp @@ -0,0 +1,136 @@ +/** + * \file dnn/src/cuda/matrix_mul/cutlass_matrix_mul_base.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#include "src/cuda/handle.h" +#include "src/cuda/matrix_mul/algos.h" +#include "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuh" +#include "src/cuda/utils.h" + +#if CUDA_VERSION >= 9020 +using namespace megdnn; +using namespace cuda; + +std::string +MatrixMulForwardImpl::AlgoCutlassMatrixMulBase::AlgoParam::to_string() const { + return ssprintf("%dX%dX%d_%dX%dX%d", threadblock_m, threadblock_n, + threadblock_k, warp_m, warp_n, warp_k); +} + +std::pair<bool, TensorLayoutArray> +MatrixMulForwardImpl::AlgoCutlassMatrixMulBase::construct_aligned_layouts( + const SizeArgs& args) const { + int alignment = max_alignment(args); + int min_alignment = min_alignment_requirement(); + bool aligned = alignment >= min_alignment; + if (aligned) + return std::make_pair(!aligned, TensorLayoutArray{{}}); + auto&& param = args.opr->param(); + int m = args.layout_c.shape[0], n = args.layout_c.shape[1], + k = args.layout_a.shape[param.transposeA ?
0 : 1]; + size_t align_m = get_aligned_power2(m, min_alignment); + size_t align_n = get_aligned_power2(n, min_alignment); + size_t align_k = get_aligned_power2(k, min_alignment); + TensorLayoutArray layouts; + layouts.emplace_back(TensorLayout{{align_m, align_k}, args.layout_a.dtype}); + layouts.emplace_back(TensorLayout{{align_k, align_n}, args.layout_b.dtype}); + layouts.emplace_back(TensorLayout{{align_m, align_n}, args.layout_c.dtype}); + return std::make_pair(!aligned, std::move(layouts)); +} + +void MatrixMulForwardImpl::AlgoCutlassMatrixMulBase::exec( + const ExecArgs& args) const { + auto aligned = construct_aligned_layouts(args); + if (!aligned.first) + return do_exec(args); + const auto& layouts = aligned.second; + auto tensor_a = args.tensor_a; + auto tensor_b = args.tensor_b; + auto workspace = args.workspace; + size_t copy_size = 0; + for (const auto& ly : layouts) + copy_size += ly.span().dist_byte(); + auto&& param = args.opr->param(); + auto&& stream = cuda_stream(args.opr->handle()); + + cuda_check(cudaMemsetAsync(workspace.raw_ptr, 0, copy_size, stream)); + + auto&& relayout = args.opr->handle()->create_operator<RelayoutForward>(); + + auto copy_stride = [](const TensorLayout& src, TensorLayout& dst, + bool trans) { + dst.stride[0] = src.stride[0], dst.stride[1] = src.stride[1]; + if (trans) + std::swap(dst.stride[0], dst.stride[1]); + }; + copy_stride(layouts[0], tensor_a.layout, param.transposeA); + tensor_a.raw_ptr = workspace.raw_ptr; + relayout->exec(args.tensor_a, tensor_a); + workspace.raw_ptr += layouts[0].span().dist_byte(); + workspace.size -= layouts[0].span().dist_byte(); + + copy_stride(layouts[1], tensor_b.layout, param.transposeB); + tensor_b.raw_ptr = workspace.raw_ptr; + relayout->exec(args.tensor_b, tensor_b); + workspace.raw_ptr += layouts[1].span().dist_byte(); + workspace.size -= layouts[1].span().dist_byte(); + + decltype(tensor_a) tensor_c{workspace.raw_ptr, layouts[2]}; + workspace.raw_ptr += layouts[2].span().dist_byte(); + workspace.size -= layouts[2].span().dist_byte(); + + auto&& matmul = args.opr->handle()->create_operator<MatrixMulForward>(); + matmul->param().transposeA = false; + matmul->param().transposeB = false; + matmul->param().compute_mode = args.opr->param().compute_mode; + + tensor_a.layout = layouts[0]; + tensor_b.layout = layouts[1]; + ExecArgs args_{static_cast<MatrixMulForwardImpl*>(matmul.get()), tensor_a, + tensor_b, tensor_c, workspace}; + do_exec(args_); + + tensor_c.layout.TensorShape::operator=(args.layout_c); + relayout->exec(tensor_c, args.tensor_c); +} + +int MatrixMulForwardImpl::AlgoCutlassMatrixMulBase::max_alignment( + const SizeArgs& args) const { + auto&& dtype_a = args.layout_a.dtype; + auto&& dtype_b = args.layout_b.dtype; + auto&& dtype_c = args.layout_c.dtype; + auto get_alignment = [](const DType& dt, int len) { + int size_bits = dt.size(1) * 8; + int align = 128; + while (align > 1) { + if ((len * size_bits) % align == 0) + break; + align = align / 2; + } + return align / size_bits; + }; + int lda = args.layout_a.stride[0], ldb = args.layout_b.stride[0], + ldc = args.layout_c.stride[0]; + auto&& param = args.opr->param(); + int m = args.layout_c.shape[0], n = args.layout_c.shape[1], + k = args.layout_a.shape[param.transposeA ?
0 : 1]; + int max_align = get_alignment(dtype_a, lda); + max_align = std::min(get_alignment(dtype_a, m), max_align); + max_align = std::min(get_alignment(dtype_a, n), max_align); + max_align = std::min(get_alignment(dtype_a, k), max_align); + max_align = std::min(get_alignment(dtype_a, lda), max_align); + max_align = std::min(get_alignment(dtype_b, ldb), max_align); + max_align = std::min(get_alignment(dtype_c, ldc), max_align); + return max_align; +} +#endif + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/matrix_mul/opr_impl.h b/dnn/src/cuda/matrix_mul/opr_impl.h index b5192038..d4de4fc6 100644 --- a/dnn/src/cuda/matrix_mul/opr_impl.h +++ b/dnn/src/cuda/matrix_mul/opr_impl.h @@ -42,9 +42,12 @@ public: class AlgoBFloat16; #endif #if CUDA_VERSION >= 9020 + class AlgoCutlassMatrixMulBase; class AlgoFloat32SIMT; class AlgoFloat32SIMTSplitK; class AlgoFloat32SIMTGemvBatchedStrided; + class AlgoFloat16TensorOp; + class AlgoFloat16TensorOpSplitK; #endif class AlgoPack; diff --git a/dnn/test/common/matrix_mul.cpp b/dnn/test/common/matrix_mul.cpp index e703fce1..7e5e6b67 100644 --- a/dnn/test/common/matrix_mul.cpp +++ b/dnn/test/common/matrix_mul.cpp @@ -184,7 +184,8 @@ void matrix_mul::check_matrix_mul(DType A_dtype, DType B_dtype, DType C_dtype, const ExecutionPolicyAlgoName& algo, param::MatrixMul::Format format, size_t nbase, float eps, std::vector&& user_args, - bool force_deduce_dst) { + bool force_deduce_dst, + param::MatrixMul::ComputeMode compute_mode) { megdnn_assert(A_dtype.enumv() == B_dtype.enumv()); Checker checker(handle); checker.set_force_deduce_dst(force_deduce_dst); @@ -261,6 +262,7 @@ void matrix_mul::check_matrix_mul(DType A_dtype, DType B_dtype, DType C_dtype, Param param; param.transposeA = arg.mask & 0x1; param.transposeB = arg.mask & 0x2; + param.compute_mode = compute_mode; param.format = format; checker.set_dtype(0, A_dtype) .set_dtype(1, B_dtype) diff --git a/dnn/test/common/matrix_mul.h b/dnn/test/common/matrix_mul.h index 7c6da529..49aa8fea 100644 --- a/dnn/test/common/matrix_mul.h +++ b/dnn/test/common/matrix_mul.h @@ -69,7 +69,9 @@ void check_matrix_mul( const ExecutionPolicyAlgoName& algo = {"", {}}, param::MatrixMul::Format format = param::MatrixMul::Format::DEFAULT, size_t nbase = 8, float eps = 1e-3, std::vector&& args = {}, - bool force_deduce_dst = true); + bool force_deduce_dst = true, + param::MatrixMul::ComputeMode compute_mode = + param::MatrixMul::ComputeMode::DEFAULT); void check_matrix_mul( DType A_dtype, DType B_dtype, DType C_dtype, Handle* handle, diff --git a/dnn/test/cuda/cutlass_matmul.cpp b/dnn/test/cuda/cutlass_matmul.cpp index 03600d34..b406fcf8 100644 --- a/dnn/test/cuda/cutlass_matmul.cpp +++ b/dnn/test/cuda/cutlass_matmul.cpp @@ -21,6 +21,7 @@ #include "test/cuda/fixture.h" #include "test/cuda/utils.h" +#define MEGDNN_WITH_BENCHMARK 1 #if CUDA_VERSION >= 9020 namespace megdnn { namespace test { @@ -215,6 +216,14 @@ std::vector get_feat_model_args() { return args; } +std::vector get_f16_feat_model_args() { + std::vector args; + args.emplace_back(BenchArgs{128, 9216, 9216}); + args.emplace_back(BenchArgs{128, 6400, 6400}); + args.emplace_back(BenchArgs{128, 5184, 5184}); + return args; +} + void benchmark_matrix_mul( Handle* handle, const std::vector& args, DType A_dtype, DType B_dtype, DType C_dtype, const char* algo = nullptr, @@ -364,6 +373,82 @@ MEGDNN_FOREACH_CUTLASS_KERNEL(cb) #undef cb #undef MEGDNN_FOREACH_CUTLASS_KERNEL +#define MEGDNN_FOREACH_CUTLASS_KERNEL(cb) \ + cb(1, 256, 128, 32, 64, 64, 32, 8, 8, 4); \ + cb(2, 128, 256, 
32, 64, 64, 32, 8, 8, 4); \ + cb(3, 128, 128, 32, 64, 64, 32, 8, 8, 4); + +#define cb(name, tbm, tbn, tbk, wm, wn, wk, im, in, ik) \ + TEST_F(CUDA, CUTLASS_F16_884_GEMM_##name) { \ + require_compute_capability(7, 0); \ + matrix_mul::check_matrix_mul( \ + dtype::Float16(), dtype::Float16(), dtype::Float16(), \ + handle_cuda(), \ + "CUTLASS_FLOAT16_TENSOR_OP_h" #im #in #ik "_" #tbm "X" #tbn \ + "X" #tbk "_" #wm "X" #wn "X" #wk, \ + param::MatrixMul::Format::DEFAULT, 8, 1e-2, \ + matrix_mul::get_matmul_args()); \ + } +MEGDNN_FOREACH_CUTLASS_KERNEL(cb) + +#undef cb + +#define cb(name, tbm, tbn, tbk, wm, wn, wk, im, in, ik) \ + TEST_F(CUDA, CUTLASS_F16_884_GEMM_SPLIT_K_##name) { \ + require_compute_capability(7, 0); \ + matrix_mul::check_matrix_mul( \ + dtype::Float16(), dtype::Float16(), dtype::Float16(), \ + handle_cuda(), \ + "CUTLASS_FLOAT16_TENSOR_OP_SPLIT_K_h" #im #in #ik "_" #tbm \ + "X" #tbn "X" #tbk "_" #wm "X" #wn "X" #wk, \ + param::MatrixMul::Format::DEFAULT, 8, 1e-3, \ + matrix_mul::get_matmul_args_split_k(), true, \ + param::MatrixMul::ComputeMode::FLOAT32); \ + } +MEGDNN_FOREACH_CUTLASS_KERNEL(cb) + +#undef cb + +#undef MEGDNN_FOREACH_CUTLASS_KERNEL + +#define MEGDNN_FOREACH_CUTLASS_KERNEL(cb) \ + cb(1, 256, 128, 32, 64, 64, 32, 16, 8, 8); \ + cb(2, 128, 256, 32, 64, 64, 32, 16, 8, 8); \ + cb(3, 128, 128, 32, 64, 64, 32, 16, 8, 8); + +#define cb(name, tbm, tbn, tbk, wm, wn, wk, im, in, ik) \ + TEST_F(CUDA, CUTLASS_F16_1688_GEMM_##name) { \ + require_compute_capability(7, 5); \ + matrix_mul::check_matrix_mul( \ + dtype::Float16(), dtype::Float16(), dtype::Float16(), \ + handle_cuda(), \ + "CUTLASS_FLOAT16_TENSOR_OP_h" #im #in #ik "_" #tbm "X" #tbn \ + "X" #tbk "_" #wm "X" #wn "X" #wk, \ + param::MatrixMul::Format::DEFAULT, 8, 1e-2, \ + matrix_mul::get_matmul_args(), true, \ + param::MatrixMul::ComputeMode::FLOAT32); \ + } +MEGDNN_FOREACH_CUTLASS_KERNEL(cb) + +#undef cb + +#define cb(name, tbm, tbn, tbk, wm, wn, wk, im, in, ik) \ + TEST_F(CUDA, CUTLASS_F16_1688_GEMM_SPLIT_K_##name) { \ + require_compute_capability(7, 5); \ + matrix_mul::check_matrix_mul( \ + dtype::Float16(), dtype::Float16(), dtype::Float16(), \ + handle_cuda(), \ + "CUTLASS_FLOAT16_TENSOR_OP_SPLIT_K_h" #im #in #ik "_" #tbm \ + "X" #tbn "X" #tbk "_" #wm "X" #wn "X" #wk, \ + param::MatrixMul::Format::DEFAULT, 8, 1e-3, \ + matrix_mul::get_matmul_args_split_k()); \ + } +MEGDNN_FOREACH_CUTLASS_KERNEL(cb) + +#undef cb + +#undef MEGDNN_FOREACH_CUTLASS_KERNEL + #if MEGDNN_WITH_BENCHMARK TEST_F(CUDA, BENCHMARK_CUTLASS_MATMUL) { benchmark_matrix_mul(handle_cuda(), get_square_matmul_args(), @@ -376,6 +461,12 @@ TEST_F(CUDA, BENCHMARK_CUTLASS_MATMUL_FEAT) { dtype::Float32(), dtype::Float32(), "CUTLASS_FLOAT32_SIMT"); } + +TEST_F(CUDA, BENCHMARK_CUTLASS_F16_MATMUL_FEAT) { + benchmark_matrix_mul(handle_cuda(), get_f16_feat_model_args(), + dtype::Float16(), dtype::Float16(), dtype::Float16(), + "CUTLASS_FLOAT16_TENSOR_OP"); +} #endif } // namespace test } // namespace megdnn
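
Note on the availability check in AlgoFloat16TensorOpSplitK::is_available: the split-k fp16 tensor-op algorithm is only reported when the problem is K-dominant (k > n), when the number of threadblock tiles along N fits the CUDA y-grid limit, and when the device supports the requested MMA shape (8x8x4 needs sm_70, 16x8x8 needs sm_75). A minimal standalone sketch of the same gating logic follows; the struct and function names are illustrative, not the MegDNN API.

#include <cassert>

struct TensorOpAlgoParam {
    int threadblock_n;
    int instruction_m, instruction_n, instruction_k;
};

// Returns true when an fp16 tensor-op split-k kernel with the given tile
// parameters could run an (m, n, k) problem on a device with the given
// compute capability and maximum y grid size.
bool fp16_tensorop_split_k_available(const TensorOpAlgoParam& p, int n, int k,
                                     int sm_major, int sm_minor,
                                     int y_grid_limit) {
    bool available = k > n;  // split-k only pays off when K dominates N
    // one threadblock per N tile is mapped to the y grid dimension
    available &= (n + p.threadblock_n - 1) / p.threadblock_n <= y_grid_limit;
    int cc = sm_major * 10 + sm_minor;
    if (p.instruction_m == 8 && p.instruction_n == 8 && p.instruction_k == 4) {
        available &= cc >= 70;  // mma.m8n8k4 needs Volta or newer
    } else {
        assert(p.instruction_m == 16 && p.instruction_n == 8 &&
               p.instruction_k == 8);
        available &= cc >= 75;  // mma.m16n8k8 needs Turing or newer
    }
    return available;
}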
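
The workspace computed by AlgoFloat16TensorOpSplitK::get_workspace_in_bytes reserves one C-sized partial-result buffer per split-k slice, with split_k_slices = max(1, k / n); when the operands need alignment padding, the slice count is recomputed from the padded extents and the padded copies of A, B and C are appended to the workspace. A simplified sketch of that accounting, assuming the padded extents are already known (the helper name and parameters are made up for illustration):

#include <algorithm>
#include <cstddef>

// Workspace for a parallel split-k GEMM: every slice writes an m x n partial
// result that is reduced afterwards; the padded operand copies (if any) also
// live in the workspace.
size_t split_k_workspace_bytes(size_t m, size_t n, size_t k, size_t elem_size,
                               bool needs_padding, size_t pad_m, size_t pad_n,
                               size_t pad_k) {
    if (!needs_padding) {
        size_t split_k_slices = std::max<size_t>(1, k / n);
        return m * n * split_k_slices * elem_size;
    }
    size_t split_k_slices = std::max<size_t>(1, pad_k / pad_n);
    size_t ws = pad_m * pad_n * split_k_slices * elem_size;  // partial sums
    ws += pad_m * pad_k * elem_size;                         // padded A
    ws += pad_k * pad_n * elem_size;                         // padded B
    ws += pad_m * pad_n * elem_size;                         // padded C
    return ws;
}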
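
The comment in AlgoFloat16TensorOpSplitK::do_exec explains why the host-side alpha/beta scalars must have exactly the ElementCompute type of the selected kernel: they travel as void* and are reinterpreted on the other end, so handing a float where the kernel expects a half silently yields garbage. A small hedged sketch of that selection, using CUDA's __half as the fp16 host type; the enum and function are illustrative only.

#include <cuda_fp16.h>

enum class ComputeMode { kFloat16, kFloat32 };

// The scalars are forwarded as void* and reinterpreted as the kernel's
// ElementCompute type, so their host representation must match it exactly.
void select_epilogue_scalars(ComputeMode mode, const void** alpha,
                             const void** beta) {
    static const float one_f32 = 1.f, zero_f32 = 0.f;
    static const __half one_f16 = __float2half(1.f),
                        zero_f16 = __float2half(0.f);
    if (mode == ComputeMode::kFloat32) {
        *alpha = &one_f32;  // kernel scales/accumulates in fp32
        *beta = &zero_f32;
    } else {
        *alpha = &one_f16;  // kernel's ElementCompute is fp16
        *beta = &zero_f16;  // passing &one_f32 here would be misread as __half
    }
}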
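
When the operands do not meet the minimum alignment of the tensor-op kernels, AlgoCutlassMatrixMulBase::exec builds padded, contiguous layouts whose extents are rounded up to the alignment, zero-fills the workspace, relayouts A and B into it, runs the GEMM on the padded problem, and relayouts the valid m x n region of the padded C back into the output. A sketch of the rounding and padded-extent computation, under the assumption that get_aligned_power2 rounds up to a multiple of a power-of-two alignment (the helper names below are illustrative):

#include <cstddef>

// Round x up to the next multiple of `align`, where `align` is a power of two.
inline size_t round_up_pow2(size_t x, size_t align) {
    return (x + align - 1) & ~(align - 1);
}

struct PaddedGemmExtents {
    size_t m, n, k;
};

// Padded problem size so that every dimension (and hence every leading
// dimension of the contiguous padded tensors) satisfies the kernel's
// minimum alignment requirement.
inline PaddedGemmExtents pad_problem(size_t m, size_t n, size_t k,
                                     size_t min_alignment) {
    return {round_up_pow2(m, min_alignment), round_up_pow2(n, min_alignment),
            round_up_pow2(k, min_alignment)};
}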
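
AlgoCutlassMatrixMulBase::max_alignment derives the widest vector access, in elements, that divides every problem dimension and every leading dimension, starting from 128-bit accesses and halving until one fits; min_alignment_requirement() then decides whether the kernel can run directly or must fall back to the padded path above. A standalone sketch of the per-length computation (simplified: it uses one element size for all three operands, whereas the real code distinguishes the dtypes of A, B and C):

#include <algorithm>
#include <initializer_list>

// Largest alignment, in elements, such that `len` elements of `elem_bytes`
// each can be accessed with vector loads of at most 128 bits.
inline int alignment_for(int elem_bytes, int len) {
    int size_bits = elem_bytes * 8;
    int align_bits = 128;  // start from 128-bit (e.g. float4 / 8 x fp16) loads
    while (align_bits > 1 && (len * size_bits) % align_bits != 0)
        align_bits /= 2;
    return align_bits / size_bits;  // convert bits back to elements
}

// The usable alignment of the whole GEMM is the minimum over m, n, k and the
// three leading dimensions.
inline int max_gemm_alignment(int elem_bytes, int m, int n, int k, int lda,
                              int ldb, int ldc) {
    int align = alignment_for(elem_bytes, lda);
    for (int len : {m, n, k, ldb, ldc})
        align = std::min(align, alignment_for(elem_bytes, len));
    return align;
}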
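
The new test macros in dnn/test/cuda/cutlass_matmul.cpp stringize the instruction and tile shapes into the algorithm name that check_matrix_mul selects by, e.g. CUTLASS_FLOAT16_TENSOR_OP_h1688_256X128X32_64X64X32; the threadblock/warp suffix is exactly what the new AlgoParam::to_string produces. A sketch of that naming convention, inferred from the test strings (not a MegDNN helper):

#include <cstdio>
#include <string>

struct TileConfig {
    int threadblock_m, threadblock_n, threadblock_k;
    int warp_m, warp_n, warp_k;
    int instruction_m, instruction_n, instruction_k;
};

// Mirrors the names used by the fp16 tensor-op algorithms, e.g.
//   CUTLASS_FLOAT16_TENSOR_OP_h884_256X128X32_64X64X32
//   CUTLASS_FLOAT16_TENSOR_OP_SPLIT_K_h1688_128X256X32_64X64X32
std::string fp16_tensorop_algo_name(const TileConfig& c, bool split_k) {
    char buf[128];
    std::snprintf(buf, sizeof(buf),
                  "CUTLASS_FLOAT16_TENSOR_OP%s_h%d%d%d_%dX%dX%d_%dX%dX%d",
                  split_k ? "_SPLIT_K" : "", c.instruction_m, c.instruction_n,
                  c.instruction_k, c.threadblock_m, c.threadblock_n,
                  c.threadblock_k, c.warp_m, c.warp_n, c.warp_k);
    return buf;
}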