From 47fe76631080614b51ebb47bab900a61eed2598f Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 14 Feb 2022 16:47:49 +0800 Subject: [PATCH] feat(dnn/cuda): add implicit bmm kernels for large kernel depthwise convolution backward filter opr GitOrigin-RevId: 932e7689e89f2864884546935ca13f656842f41c --- dnn/scripts/cutlass_generator/BUILD | 2 + dnn/scripts/cutlass_generator/conv2d_operation.py | 129 ++++++++++++++-- dnn/scripts/cutlass_generator/gen_list.py | 7 +- dnn/scripts/cutlass_generator/generator.py | 94 ++++++++--- dnn/scripts/cutlass_generator/library.py | 1 + dnn/scripts/cutlass_generator/list.bzl | 58 ++++++- dnn/src/CMakeLists.txt | 2 + dnn/src/cuda/conv_bias/algo.cpp | 2 +- .../implicit_batched_gemm_float16_nchw_hmma.cpp | 1 + .../implicit_batched_gemm_float32_nchw_fma.cpp | 1 + dnn/src/cuda/conv_bias/opr_impl.cpp | 4 +- dnn/src/cuda/convolution/backward_data/algo.cpp | 2 +- .../implicit_batched_gemm_float16_nchw_hmma.cpp | 5 +- .../implicit_batched_gemm_float32_nchw_fma.cpp | 5 +- dnn/src/cuda/convolution/backward_filter/algo.cpp | 34 ++++ dnn/src/cuda/convolution/backward_filter/algo.h | 81 ++++++++++ .../implicit_batched_gemm_float16_nchw_hmma.cpp | 172 +++++++++++++++++++++ .../implicit_batched_gemm_float32_nchw_fma.cpp | 135 ++++++++++++++++ dnn/src/cuda/convolution/opr_impl.cpp | 26 +++- dnn/src/cuda/convolution/opr_impl.h | 2 + dnn/src/cuda/cutlass/convolution_operation.h | 162 +++++++++++++++++++ dnn/src/cuda/cutlass/initialize_all.cu | 25 ++- dnn/src/cuda/cutlass/operation_table.h | 1 - dnn/src/cuda/cutlass/util.cu | 2 + dnn/src/cuda/matrix_mul/algos.cpp | 23 +-- dnn/src/cuda/matrix_mul/algos.h | 4 +- .../cuda/matrix_mul/cutlass_float16_tensorop.cpp | 2 +- .../cutlass_float16_tensorop_split_k.cpp | 2 +- dnn/src/cuda/matrix_mul/opr_impl.h | 2 + dnn/test/cuda/chanwise_convolution.cpp | 117 ++++++++++++++ dnn/test/cuda/conv_bias.cpp | 3 + dnn/test/cuda/cutlass_matmul.cpp | 8 +- 32 files changed, 1043 insertions(+), 71 deletions(-) create mode 100644 dnn/src/cuda/convolution/backward_filter/implicit_batched_gemm_float16_nchw_hmma.cpp create mode 100644 dnn/src/cuda/convolution/backward_filter/implicit_batched_gemm_float32_nchw_fma.cpp diff --git a/dnn/scripts/cutlass_generator/BUILD b/dnn/scripts/cutlass_generator/BUILD index 64e61884..d90fc438 100644 --- a/dnn/scripts/cutlass_generator/BUILD +++ b/dnn/scripts/cutlass_generator/BUILD @@ -17,6 +17,8 @@ genrule( CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations dwconv2d_fprop --type tensorop884 $(@D) CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations dwconv2d_dgrad --type simt $(@D) CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations dwconv2d_dgrad --type tensorop884 $(@D) + CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations dwconv2d_wgrad --type simt $(@D) + CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations dwconv2d_wgrad --type tensorop884 $(@D) """, tools = ["//brain/megbrain/dnn/scripts/cutlass_generator:generator.py"], visibility = ["//visibility:public"], diff --git a/dnn/scripts/cutlass_generator/conv2d_operation.py b/dnn/scripts/cutlass_generator/conv2d_operation.py index 89b6d258..cc9c7230 100644 --- a/dnn/scripts/cutlass_generator/conv2d_operation.py +++ b/dnn/scripts/cutlass_generator/conv2d_operation.py @@ -317,7 +317,7 @@ class EmitDeconvInstance: def __init__(self): self.template = """ // kernel instance "${operation_name}" generated by cutlass generator -using Deconvolution = +using Convolution = typename cutlass::conv::device::Deconvolution< ${element_src}, ${layout_src}, @@ 
-415,6 +415,103 @@ using Deconvolution = return SubstituteTemplate(self.template, values) +class EmitConvolutionBackwardFilterInstance: + def __init__(self): + self.template = """ +// kernel instance "${operation_name}" generated by cutlass generator +using Convolution = + typename cutlass::conv::device::ConvolutionBackwardFilter< + ${element_src}, + ${layout_src}, + ${element_diff}, + ${layout_diff}, + ${element_grad}, + ${layout_grad}, + ${element_accumulator}, + ${conv_type}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + ${epilogue_functor}< + ${element_grad}, + ${epilogue_vector_length}, + ${element_accumulator}, + ${element_epilogue} + >, + ${swizzling_functor}, + ${stages}, + ${alignment_src}, + ${alignment_diff}, + ${special_optimization}, + ${math_operator}, + ${implicit_gemm_mode}>; +""" + + def emit(self, operation): + + warp_shape = [ + int( + operation.tile_description.threadblock_shape[idx] + / operation.tile_description.warp_count[idx] + ) + for idx in range(3) + ] + + epilogue_vector_length = int( + min(operation.dst.alignment * DataTypeSize[operation.dst.element], 128) + / DataTypeSize[operation.dst.element] + ) + + values = { + "operation_name": operation.procedural_name(), + "conv_type": ConvTypeTag[operation.conv_type], + "element_src": DataTypeTag[operation.src.element], + "layout_src": LayoutTag[operation.src.layout], + "element_diff": DataTypeTag[operation.flt.element], + "layout_diff": LayoutTag[operation.flt.layout], + "element_grad": DataTypeTag[operation.dst.element], + "layout_grad": LayoutTag[operation.dst.layout], + "element_accumulator": DataTypeTag[operation.accumulator_type()], + "opcode_class": OpcodeClassTag[ + operation.tile_description.math_instruction.opcode_class + ], + "arch": "cutlass::arch::Sm%d" % operation.arch, + "threadblock_shape_m": str(operation.tile_description.threadblock_shape[0]), + "threadblock_shape_n": str(operation.tile_description.threadblock_shape[1]), + "threadblock_shape_k": str(operation.tile_description.threadblock_shape[2]), + "warp_shape_m": str(warp_shape[0]), + "warp_shape_n": str(warp_shape[1]), + "warp_shape_k": str(warp_shape[2]), + "instruction_shape_m": str( + operation.tile_description.math_instruction.instruction_shape[0] + ), + "instruction_shape_n": str( + operation.tile_description.math_instruction.instruction_shape[1] + ), + "instruction_shape_k": str( + operation.tile_description.math_instruction.instruction_shape[2] + ), + "epilogue_vector_length": str(epilogue_vector_length), + "epilogue_functor": EpilogueFunctorTag[operation.epilogue_functor], + "element_epilogue": str(DataTypeTag[operation.element_epilogue]), + "swizzling_functor": SwizzlingFunctorTag[operation.swizzling_functor], + "stages": str(operation.tile_description.stages), + "alignment_src": str(operation.src.alignment), + "alignment_diff": str(operation.flt.alignment), + "special_optimization": SpecialOptimizeDescTag[ + operation.special_optimization + ], + "math_operator": MathOperationTag[ + operation.tile_description.math_instruction.math_operation + ], + "implicit_gemm_mode": ImplicitGemmModeTag[operation.implicit_gemm_mode], + } + + return SubstituteTemplate(self.template, values) + + ################################################################################################### # 
# Generator functions for all layouts @@ -500,6 +597,7 @@ def GenerateConv2d( epilogues = [ EpilogueFunctor.BiasAddLinearCombination, EpilogueFunctor.BiasAddLinearCombinationRelu, + EpilogueFunctor.LinearCombination, ] if conv_type == ConvType.Convolution: epilogues.append(EpilogueFunctor.BiasAddLinearCombinationHSwish) @@ -544,11 +642,15 @@ def GenerateConv2d( def filter_epilogue_with_conv_kind( epilogue: EpilogueFunctor, conv_kind: ConvKind ) -> bool: - return ( - (conv_kind == ConvKind.Dgrad or conv_kind == ConvKind.Wgrad) - and epilogue != EpilogueFunctor.BiasAddLinearCombinationClamp - and epilogue != EpilogueFunctor.BiasAddLinearCombination - ) + if conv_kind == ConvKind.Fprop: + return epilogue == EpilogueFunctor.LinearCombination + elif conv_kind == ConvKind.Dgrad: + return ( + epilogue != EpilogueFunctor.BiasAddLinearCombinationClamp + and epilogue != EpilogueFunctor.BiasAddLinearCombination + ) + elif conv_kind == ConvKind.Wgrad: + return epilogue != EpilogueFunctor.LinearCombination # loop over all tile descriptions for tile in tile_descriptions: @@ -557,7 +659,7 @@ def GenerateConv2d( bias_type, epilogues = get_bias_type_and_epilogues(tile, dst_type) - flt_align = get_flt_align(tile) + flt_align = flt_align if conv_kind == ConvKind.Wgrad else get_flt_align(tile) dst_align = get_dst_align(tile, dst_layout) @@ -771,11 +873,14 @@ class EmitConvSingleKernelWrapper: if self.operation.conv_kind == ConvKind.Fprop: self.instance_emitter = EmitConv2dInstance() - self.convolution_name = "Convolution" - else: - assert self.operation.conv_kind == ConvKind.Dgrad + self.convolution_name = "ConvolutionOperation" + elif self.operation.conv_kind == ConvKind.Dgrad: self.instance_emitter = EmitDeconvInstance() - self.convolution_name = "Deconvolution" + self.convolution_name = "ConvolutionOperation" + else: + assert self.operation.conv_kind == ConvKind.Wgrad + self.instance_emitter = EmitConvolutionBackwardFilterInstance() + self.convolution_name = "ConvolutionBackwardFilterOperation" self.header_template = """ #if __CUDACC_VER_MAJOR__ > ${required_cuda_ver_major} || (__CUDACC_VER_MAJOR__ == ${required_cuda_ver_major} && __CUDACC_VER_MINOR__ >= ${required_cuda_ver_minor}) @@ -800,7 +905,7 @@ namespace cutlass { namespace library { void initialize_${operation_name}(Manifest &manifest) { - manifest.append(new ConvolutionOperation<${convolution_name}>( + manifest.append(new ${convolution_name}( "${operation_name}" )); } diff --git a/dnn/scripts/cutlass_generator/gen_list.py b/dnn/scripts/cutlass_generator/gen_list.py index c6758695..11805fd0 100644 --- a/dnn/scripts/cutlass_generator/gen_list.py +++ b/dnn/scripts/cutlass_generator/gen_list.py @@ -3,8 +3,9 @@ from generator import ( GenerateGemvOperations, GenerateConv2dOperations, GenerateDeconvOperations, - GenerateDwconv2dFpropOperations, + GenerateDwconv2dFpropOperations, GenerateDwconv2dDgradOperations, + GenerateDwconv2dWgradOperations, ) @@ -28,7 +29,7 @@ def write_op_list(f, gen_op, gen_type): elif gen_op == "dwconv2d_dgrad": operations = GenerateDwconv2dDgradOperations(GenArg(gen_op, gen_type)) elif gen_op == "dwconv2d_wgrad": - pass + operations = GenerateDwconv2dWgradOperations(GenArg(gen_op, gen_type)) for op in operations: f.write(' "%s.cu",\n' % op.procedural_name()) if gen_op != "gemv": @@ -52,4 +53,6 @@ if __name__ == "__main__": write_op_list(f, "dwconv2d_fprop", "tensorop884") write_op_list(f, "dwconv2d_dgrad", "simt") write_op_list(f, "dwconv2d_dgrad", "tensorop884") + write_op_list(f, "dwconv2d_wgrad", "simt") + write_op_list(f, 
"dwconv2d_wgrad", "tensorop884") f.write("]") diff --git a/dnn/scripts/cutlass_generator/generator.py b/dnn/scripts/cutlass_generator/generator.py index 68c2dfc5..45838134 100644 --- a/dnn/scripts/cutlass_generator/generator.py +++ b/dnn/scripts/cutlass_generator/generator.py @@ -1115,7 +1115,10 @@ def GenerateDwconv2d_Simt(args, conv_kind): dst_types = [DataType.f32] - alignment_constraints = [128, 32] + if conv_kind == ConvKind.Wgrad: + alignment_constraints = [32] + else: + alignment_constraints = [128, 32] operations = [] for math_inst in math_instructions: @@ -1244,7 +1247,9 @@ def GenerateDwconv2d_Simt(args, conv_kind): 32, 32, SpecialOptimizeDesc.NoneSpecialOpt, - ImplicitGemmMode.GemmTN, + ImplicitGemmMode.GemmNT + if conv_kind == ConvKind.Wgrad + else ImplicitGemmMode.GemmTN, ) return operations @@ -1277,11 +1282,14 @@ def GenerateDwconv2d_TensorOp_884(args, conv_kind): dst_layouts = [LayoutType.TensorNCHW] - dst_types = [DataType.f16] + if conv_kind == ConvKind.Wgrad: + dst_types = [DataType.f32] + else: + dst_types = [DataType.f16] alignment_constraints = [128, 32, 16] cuda_major = 10 - cuda_minor = 2 + cuda_minor = 1 operations = [] for math_inst in math_instructions: @@ -1295,24 +1303,48 @@ def GenerateDwconv2d_TensorOp_884(args, conv_kind): for layout in layouts: for dst_type, dst_layout in zip(dst_types, dst_layouts): for alignment_src in alignment_constraints: - operations += GenerateConv2d( - ConvType.DepthwiseConvolution, - conv_kind, - tile_descriptions, - layout[0], - layout[1], - dst_layout, - dst_type, - min_cc, - alignment_src, - 16, - 16, - SpecialOptimizeDesc.NoneSpecialOpt, - ImplicitGemmMode.GemmTN, - False, - cuda_major, - cuda_minor, - ) + if conv_kind == ConvKind.Wgrad: + # skip io16xc16 + if math_inst.element_accumulator == DataType.f16: + continue + for alignment_diff in alignment_constraints: + operations += GenerateConv2d( + ConvType.DepthwiseConvolution, + conv_kind, + tile_descriptions, + layout[0], + layout[1], + dst_layout, + dst_type, + min_cc, + alignment_src, + alignment_diff, + 32, # always f32 output + SpecialOptimizeDesc.NoneSpecialOpt, + ImplicitGemmMode.GemmNT, + False, + cuda_major, + cuda_minor, + ) + else: + operations += GenerateConv2d( + ConvType.DepthwiseConvolution, + conv_kind, + tile_descriptions, + layout[0], + layout[1], + dst_layout, + dst_type, + min_cc, + alignment_src, + 16, + 16, + SpecialOptimizeDesc.NoneSpecialOpt, + ImplicitGemmMode.GemmTN, + False, + cuda_major, + cuda_minor, + ) return operations @@ -1501,7 +1533,7 @@ def GeneratesGemm_TensorOp_884(args): # 1 ] cuda_major = 10 - cuda_minor = 2 + cuda_minor = 1 operations = [] for math_inst in math_instructions: @@ -1595,6 +1627,17 @@ def GenerateDwconv2dDgradOperations(args): return GenerateDwconv2d_TensorOp_884(args, ConvKind.Dgrad) +def GenerateDwconv2dWgradOperations(args): + if args.type == "simt": + return GenerateDwconv2d_Simt(args, ConvKind.Wgrad) + else: + assert args.type == "tensorop884", ( + "operation dwconv2d fprop only support" + "simt, tensorop884. 
(got:{})".format(args.type) + ) + return GenerateDwconv2d_TensorOp_884(args, ConvKind.Wgrad) + + def GenerateGemmOperations(args): if args.type == "tensorop884": return GeneratesGemm_TensorOp_884(args) @@ -1668,8 +1711,9 @@ if __name__ == "__main__": operations = GenerateDwconv2dFpropOperations(args) elif args.operations == "dwconv2d_dgrad": operations = GenerateDwconv2dDgradOperations(args) - elif args.operations == "dwconv2d_wgrad": - pass + else: + assert args.operations == "dwconv2d_wgrad", "invalid operation" + operations = GenerateDwconv2dWgradOperations(args) if ( args.operations == "conv2d" diff --git a/dnn/scripts/cutlass_generator/library.py b/dnn/scripts/cutlass_generator/library.py index 60f2f3af..b1b669d8 100644 --- a/dnn/scripts/cutlass_generator/library.py +++ b/dnn/scripts/cutlass_generator/library.py @@ -483,6 +483,7 @@ EpilogueFunctorTag = { # ShortEpilogueNames = { + EpilogueFunctor.LinearCombination: "id", EpilogueFunctor.BiasAddLinearCombinationHSwishClamp: "hswish", EpilogueFunctor.BiasAddLinearCombinationReluClamp: "relu", EpilogueFunctor.BiasAddLinearCombinationClamp: "id", diff --git a/dnn/scripts/cutlass_generator/list.bzl b/dnn/scripts/cutlass_generator/list.bzl index 76b877b8..683e092f 100644 --- a/dnn/scripts/cutlass_generator/list.bzl +++ b/dnn/scripts/cutlass_generator/list.bzl @@ -1382,4 +1382,60 @@ cutlass_gen_list = [ "cutlass_tensorop_h884dwdgrad_id_f16_128x64x32_32x32x32_2_nchw_nchw_align1x1.cu", "cutlass_tensorop_h884dwdgrad_id_f16_64x64x32_32x32x32_2_nchw_nchw_align1x1.cu", "all_dwconv2d_dgrad_tensorop884_operations.cu", -] + "cutlass_simt_sdwwgrad_id_f32_32x32x8_32x32x8_2_nchw_nchw_align1x1.cu", + "cutlass_simt_sdwwgrad_id_f32_32x64x8_32x64x8_2_nchw_nchw_align1x1.cu", + "cutlass_simt_sdwwgrad_id_f32_64x32x8_64x32x8_2_nchw_nchw_align1x1.cu", + "cutlass_simt_sdwwgrad_id_f32_32x128x8_32x64x8_2_nchw_nchw_align1x1.cu", + "cutlass_simt_sdwwgrad_id_f32_64x64x8_32x64x8_2_nchw_nchw_align1x1.cu", + "cutlass_simt_sdwwgrad_id_f32_128x32x8_64x32x8_2_nchw_nchw_align1x1.cu", + "cutlass_simt_sdwwgrad_id_f32_64x128x8_32x64x8_2_nchw_nchw_align1x1.cu", + "cutlass_simt_sdwwgrad_id_f32_128x64x8_64x32x8_2_nchw_nchw_align1x1.cu", + "cutlass_simt_sdwwgrad_id_f32_128x128x8_32x64x8_2_nchw_nchw_align1x1.cu", + "all_dwconv2d_wgrad_simt_operations.cu", + "cutlass_tensorop_s884dwwgrad_id_f16_128x256x32_64x64x32_2_nchw_nchw_align8x8.cu", + "cutlass_tensorop_s884dwwgrad_id_f16_128x128x32_32x32x32_2_nchw_nchw_align8x8.cu", + "cutlass_tensorop_s884dwwgrad_id_f16_64x128x32_32x32x32_2_nchw_nchw_align8x8.cu", + "cutlass_tensorop_s884dwwgrad_id_f16_128x64x32_32x32x32_2_nchw_nchw_align8x8.cu", + "cutlass_tensorop_s884dwwgrad_id_f16_64x64x32_32x32x32_2_nchw_nchw_align8x8.cu", + "cutlass_tensorop_s884dwwgrad_id_f16_128x256x32_64x64x32_2_nchw_nchw_align8x2.cu", + "cutlass_tensorop_s884dwwgrad_id_f16_128x128x32_32x32x32_2_nchw_nchw_align8x2.cu", + "cutlass_tensorop_s884dwwgrad_id_f16_64x128x32_32x32x32_2_nchw_nchw_align8x2.cu", + "cutlass_tensorop_s884dwwgrad_id_f16_128x64x32_32x32x32_2_nchw_nchw_align8x2.cu", + "cutlass_tensorop_s884dwwgrad_id_f16_64x64x32_32x32x32_2_nchw_nchw_align8x2.cu", + "cutlass_tensorop_s884dwwgrad_id_f16_128x256x32_64x64x32_2_nchw_nchw_align8x1.cu", + "cutlass_tensorop_s884dwwgrad_id_f16_128x128x32_32x32x32_2_nchw_nchw_align8x1.cu", + "cutlass_tensorop_s884dwwgrad_id_f16_64x128x32_32x32x32_2_nchw_nchw_align8x1.cu", + "cutlass_tensorop_s884dwwgrad_id_f16_128x64x32_32x32x32_2_nchw_nchw_align8x1.cu", + 
"cutlass_tensorop_s884dwwgrad_id_f16_64x64x32_32x32x32_2_nchw_nchw_align8x1.cu", + "cutlass_tensorop_s884dwwgrad_id_f16_128x256x32_64x64x32_2_nchw_nchw_align2x8.cu", + "cutlass_tensorop_s884dwwgrad_id_f16_128x128x32_32x32x32_2_nchw_nchw_align2x8.cu", + "cutlass_tensorop_s884dwwgrad_id_f16_64x128x32_32x32x32_2_nchw_nchw_align2x8.cu", + "cutlass_tensorop_s884dwwgrad_id_f16_128x64x32_32x32x32_2_nchw_nchw_align2x8.cu", + "cutlass_tensorop_s884dwwgrad_id_f16_64x64x32_32x32x32_2_nchw_nchw_align2x8.cu", + "cutlass_tensorop_s884dwwgrad_id_f16_128x256x32_64x64x32_2_nchw_nchw_align2x2.cu", + "cutlass_tensorop_s884dwwgrad_id_f16_128x128x32_32x32x32_2_nchw_nchw_align2x2.cu", + "cutlass_tensorop_s884dwwgrad_id_f16_64x128x32_32x32x32_2_nchw_nchw_align2x2.cu", + "cutlass_tensorop_s884dwwgrad_id_f16_128x64x32_32x32x32_2_nchw_nchw_align2x2.cu", + "cutlass_tensorop_s884dwwgrad_id_f16_64x64x32_32x32x32_2_nchw_nchw_align2x2.cu", + "cutlass_tensorop_s884dwwgrad_id_f16_128x256x32_64x64x32_2_nchw_nchw_align2x1.cu", + "cutlass_tensorop_s884dwwgrad_id_f16_128x128x32_32x32x32_2_nchw_nchw_align2x1.cu", + "cutlass_tensorop_s884dwwgrad_id_f16_64x128x32_32x32x32_2_nchw_nchw_align2x1.cu", + "cutlass_tensorop_s884dwwgrad_id_f16_128x64x32_32x32x32_2_nchw_nchw_align2x1.cu", + "cutlass_tensorop_s884dwwgrad_id_f16_64x64x32_32x32x32_2_nchw_nchw_align2x1.cu", + "cutlass_tensorop_s884dwwgrad_id_f16_128x256x32_64x64x32_2_nchw_nchw_align1x8.cu", + "cutlass_tensorop_s884dwwgrad_id_f16_128x128x32_32x32x32_2_nchw_nchw_align1x8.cu", + "cutlass_tensorop_s884dwwgrad_id_f16_64x128x32_32x32x32_2_nchw_nchw_align1x8.cu", + "cutlass_tensorop_s884dwwgrad_id_f16_128x64x32_32x32x32_2_nchw_nchw_align1x8.cu", + "cutlass_tensorop_s884dwwgrad_id_f16_64x64x32_32x32x32_2_nchw_nchw_align1x8.cu", + "cutlass_tensorop_s884dwwgrad_id_f16_128x256x32_64x64x32_2_nchw_nchw_align1x2.cu", + "cutlass_tensorop_s884dwwgrad_id_f16_128x128x32_32x32x32_2_nchw_nchw_align1x2.cu", + "cutlass_tensorop_s884dwwgrad_id_f16_64x128x32_32x32x32_2_nchw_nchw_align1x2.cu", + "cutlass_tensorop_s884dwwgrad_id_f16_128x64x32_32x32x32_2_nchw_nchw_align1x2.cu", + "cutlass_tensorop_s884dwwgrad_id_f16_64x64x32_32x32x32_2_nchw_nchw_align1x2.cu", + "cutlass_tensorop_s884dwwgrad_id_f16_128x256x32_64x64x32_2_nchw_nchw_align1x1.cu", + "cutlass_tensorop_s884dwwgrad_id_f16_128x128x32_32x32x32_2_nchw_nchw_align1x1.cu", + "cutlass_tensorop_s884dwwgrad_id_f16_64x128x32_32x32x32_2_nchw_nchw_align1x1.cu", + "cutlass_tensorop_s884dwwgrad_id_f16_128x64x32_32x32x32_2_nchw_nchw_align1x1.cu", + "cutlass_tensorop_s884dwwgrad_id_f16_64x64x32_32x32x32_2_nchw_nchw_align1x1.cu", + "all_dwconv2d_wgrad_tensorop884_operations.cu", +] \ No newline at end of file diff --git a/dnn/src/CMakeLists.txt b/dnn/src/CMakeLists.txt index 962ce874..2113839c 100644 --- a/dnn/src/CMakeLists.txt +++ b/dnn/src/CMakeLists.txt @@ -185,6 +185,8 @@ if(MGE_WITH_CUDA) gen_cutlass_kimpl(dwconv2d_fprop tensorop884 CUTLASS_SOURCES) gen_cutlass_kimpl(dwconv2d_dgrad simt CUTLASS_SOURCES) gen_cutlass_kimpl(dwconv2d_dgrad tensorop884 CUTLASS_SOURCES) + gen_cutlass_kimpl(dwconv2d_wgrad simt CUTLASS_SOURCES) + gen_cutlass_kimpl(dwconv2d_wgrad tensorop884 CUTLASS_SOURCES) list(APPEND SOURCES ${CUTLASS_SOURCES}) list(APPEND SOURCES ${CUSOURCES}) endif() diff --git a/dnn/src/cuda/conv_bias/algo.cpp b/dnn/src/cuda/conv_bias/algo.cpp index bc5bcb29..2cc0c327 100644 --- a/dnn/src/cuda/conv_bias/algo.cpp +++ b/dnn/src/cuda/conv_bias/algo.cpp @@ -317,7 +317,7 @@ void ConvBiasForwardImpl::AlgoPack::fill_dwconv_algos() { for (auto&& algo : 
f32_implicit_bmm) { all_algos.push_back(&algo); } -#if CUDA_VERSION >= 10020 +#if CUDA_VERSION >= 10010 /// preferred algo f16_implicit_bmm.emplace_back(AlgoParam{64, 128, 32, 32, 32, 32, 8, 8, 4, 2}); f16_implicit_bmm.emplace_back(AlgoParam{128, 128, 32, 32, 32, 32, 8, 8, 4, 2}); diff --git a/dnn/src/cuda/conv_bias/implicit_batched_gemm_float16_nchw_hmma.cpp b/dnn/src/cuda/conv_bias/implicit_batched_gemm_float16_nchw_hmma.cpp index 48c64c58..3aba41d7 100644 --- a/dnn/src/cuda/conv_bias/implicit_batched_gemm_float16_nchw_hmma.cpp +++ b/dnn/src/cuda/conv_bias/implicit_batched_gemm_float16_nchw_hmma.cpp @@ -50,6 +50,7 @@ bool ConvBiasForwardImpl::AlgoFloat16NCHWHMMAImplicitBatchedGemm::is_available( RETURN_IF_FALSE(param.mode == Mode::CROSS_CORRELATION); // check if channelwise convolution RETURN_IF_FALSE(fm.icpg == 1 && fm.ocpg == 1); + RETURN_IF_FALSE(param.dilate_h == 1 && param.dilate_w == 1); const auto* op = get_cutlass_conv_op( args, ConvOperator::kFprop, ConvType::kDepthwiseConvolution, false, false); RETURN_IF_FALSE(op != nullptr); diff --git a/dnn/src/cuda/conv_bias/implicit_batched_gemm_float32_nchw_fma.cpp b/dnn/src/cuda/conv_bias/implicit_batched_gemm_float32_nchw_fma.cpp index 1e4fee11..cd40a1fd 100644 --- a/dnn/src/cuda/conv_bias/implicit_batched_gemm_float32_nchw_fma.cpp +++ b/dnn/src/cuda/conv_bias/implicit_batched_gemm_float32_nchw_fma.cpp @@ -50,6 +50,7 @@ bool ConvBiasForwardImpl::AlgoFloat32NCHWFMAImplicitBatchedGemm::is_available( RETURN_IF_FALSE(param.mode == Mode::CROSS_CORRELATION); // check if channelwise convolution RETURN_IF_FALSE(fm.icpg == 1 && fm.ocpg == 1); + RETURN_IF_FALSE(param.dilate_h == 1 && param.dilate_w == 1); const auto* op = get_cutlass_conv_op( args, ConvOperator::kFprop, ConvType::kDepthwiseConvolution, false, false); RETURN_IF_FALSE(op != nullptr); diff --git a/dnn/src/cuda/conv_bias/opr_impl.cpp b/dnn/src/cuda/conv_bias/opr_impl.cpp index 17ea6eb9..2cb9a04a 100644 --- a/dnn/src/cuda/conv_bias/opr_impl.cpp +++ b/dnn/src/cuda/conv_bias/opr_impl.cpp @@ -146,15 +146,17 @@ ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic( args.filter_meta.stride[0] != 1 || args.filter_meta.stride[1] != 1 || hw_size < 512; //! choose for large kernel cases - size_t fh = args.filter_meta.spatial[2], fw = args.filter_meta.spatial[3]; + size_t fh = args.filter_meta.spatial[0], fw = args.filter_meta.spatial[1]; size_t hi = src[2], wi = src[3]; const bool prefer_dnn_lk_implbmm = hi <= 2 * fh && wi <= 2 * fw; //! 
avoid bad case in cudnn, check dnn chanwise impl first if (is_chanwise) { if (prefer_dnn_lk_implbmm) { +#if CUDA_VERSION >= 10020 if (sm_algo_pack.f16_implicit_bmm[0].is_available_attribute( args, positive_attr, negative_attr, workspace_limit_in_bytes)) return &sm_algo_pack.f16_implicit_bmm[0]; +#endif if (sm_algo_pack.f32_implicit_bmm[0].is_available_attribute( args, positive_attr, negative_attr, workspace_limit_in_bytes)) return &sm_algo_pack.f32_implicit_bmm[0]; diff --git a/dnn/src/cuda/convolution/backward_data/algo.cpp b/dnn/src/cuda/convolution/backward_data/algo.cpp index acd41aa6..7f090dfd 100644 --- a/dnn/src/cuda/convolution/backward_data/algo.cpp +++ b/dnn/src/cuda/convolution/backward_data/algo.cpp @@ -72,7 +72,7 @@ void ConvolutionBackwardDataImpl::AlgoPack::fill_dwconv_algos() { all_algos.push_back(&algo); } } -#if CUDA_VERSION >= 10020 +#if CUDA_VERSION >= 10010 { using AlgoParam = AlgoFloat16NCHWHMMAImplicitBatchedGemm::AlgoParam; /// preferred algo diff --git a/dnn/src/cuda/convolution/backward_data/implicit_batched_gemm_float16_nchw_hmma.cpp b/dnn/src/cuda/convolution/backward_data/implicit_batched_gemm_float16_nchw_hmma.cpp index f1ec797d..f26207ed 100644 --- a/dnn/src/cuda/convolution/backward_data/implicit_batched_gemm_float16_nchw_hmma.cpp +++ b/dnn/src/cuda/convolution/backward_data/implicit_batched_gemm_float16_nchw_hmma.cpp @@ -24,8 +24,10 @@ const void* ConvolutionBackwardDataImpl::AlgoFloat16NCHWHMMAImplicitBatchedGemm: int alignment_diff = 0; int wo = args.diff_layout->dtype.size(args.diff_layout->operator[](3)); for (int candidate : {16, 4, 2}) { - if (wo % candidate == 0) + if (wo % candidate == 0) { alignment_diff = candidate; + break; + } } alignment_diff /= args.diff_layout->dtype.size(1); NumericTypeID accumulator_dtype = @@ -85,6 +87,7 @@ bool ConvolutionBackwardDataImpl::AlgoFloat16NCHWHMMAImplicitBatchedGemm::is_ava RETURN_IF_FALSE(param.mode == Mode::CROSS_CORRELATION); // check if channelwise convolution RETURN_IF_FALSE(fm.icpg == 1 && fm.ocpg == 1); + RETURN_IF_FALSE(param.dilate_h == 1 && param.dilate_w == 1); const auto* op = get_available_op(args); RETURN_IF_FALSE(op != nullptr); return true; diff --git a/dnn/src/cuda/convolution/backward_data/implicit_batched_gemm_float32_nchw_fma.cpp b/dnn/src/cuda/convolution/backward_data/implicit_batched_gemm_float32_nchw_fma.cpp index 17bc1a0d..0039bb37 100644 --- a/dnn/src/cuda/convolution/backward_data/implicit_batched_gemm_float32_nchw_fma.cpp +++ b/dnn/src/cuda/convolution/backward_data/implicit_batched_gemm_float32_nchw_fma.cpp @@ -24,8 +24,10 @@ const void* ConvolutionBackwardDataImpl::AlgoFloat32NCHWFMAImplicitBatchedGemm:: int alignment_diff = 0; int wo = args.diff_layout->dtype.size(args.diff_layout->operator[](3)); for (int candidate : {16, 4}) { - if (wo % candidate == 0) + if (wo % candidate == 0) { alignment_diff = candidate; + break; + } } alignment_diff /= args.diff_layout->dtype.size(1); ConvolutionKey key{ @@ -81,6 +83,7 @@ bool ConvolutionBackwardDataImpl::AlgoFloat32NCHWFMAImplicitBatchedGemm::is_avai RETURN_IF_FALSE(param.mode == Mode::CROSS_CORRELATION); // check if channelwise convolution RETURN_IF_FALSE(fm.icpg == 1 && fm.ocpg == 1); + RETURN_IF_FALSE(param.dilate_h == 1 && param.dilate_w == 1); const auto* op = get_available_op(args); RETURN_IF_FALSE(op != nullptr); return true; diff --git a/dnn/src/cuda/convolution/backward_filter/algo.cpp b/dnn/src/cuda/convolution/backward_filter/algo.cpp index ba4c886c..5246714b 100644 --- a/dnn/src/cuda/convolution/backward_filter/algo.cpp +++ 
b/dnn/src/cuda/convolution/backward_filter/algo.cpp @@ -25,6 +25,7 @@ ConvolutionBackwardFilterImpl::AlgoPack::AlgoPack() { for (auto&& i : cudnn) { all_algos.push_back(&i); } + fill_dwconv_algos(); all_algos.push_back(&matmul); all_algos.push_back(&group); @@ -48,6 +49,39 @@ ConvolutionBackwardFilterImpl::AlgoCUDNN* ConvolutionBackwardFilterImpl::AlgoPac "can not find cudnn bwd_filter algorithm %d", static_cast(algo))); } +void ConvolutionBackwardFilterImpl::AlgoPack::fill_dwconv_algos() { + { + using AlgoParam = AlgoFloat32NCHWFMAImplicitBatchedGemm::AlgoParam; + /// preferred algo + implbmm_nchw_fma.emplace_back(AlgoParam{64, 128, 8, 32, 64, 8, 2}); + implbmm_nchw_fma.emplace_back(AlgoParam{128, 128, 8, 32, 64, 8, 2}); + implbmm_nchw_fma.emplace_back(AlgoParam{128, 64, 8, 64, 32, 8, 2}); + implbmm_nchw_fma.emplace_back(AlgoParam{128, 32, 8, 64, 32, 8, 2}); + implbmm_nchw_fma.emplace_back(AlgoParam{32, 128, 8, 32, 64, 8, 2}); + implbmm_nchw_fma.emplace_back(AlgoParam{64, 64, 8, 32, 64, 8, 2}); + implbmm_nchw_fma.emplace_back(AlgoParam{32, 64, 8, 32, 64, 8, 2}); + implbmm_nchw_fma.emplace_back(AlgoParam{32, 32, 8, 32, 32, 8, 2}); + implbmm_nchw_fma.emplace_back(AlgoParam{64, 32, 8, 64, 32, 8, 2}); + for (auto&& algo : implbmm_nchw_fma) { + all_algos.push_back(&algo); + } + } +#if CUDA_VERSION >= 10010 + { + using AlgoParam = AlgoFloat16NCHWHMMAImplicitBatchedGemm::AlgoParam; + /// preferred algo + implbmm_nchw_hmma.emplace_back(AlgoParam{64, 128, 32, 32, 32, 32, 8, 8, 4, 2}); + implbmm_nchw_hmma.emplace_back(AlgoParam{128, 128, 32, 32, 32, 32, 8, 8, 4, 2}); + implbmm_nchw_hmma.emplace_back(AlgoParam{128, 256, 32, 64, 64, 32, 8, 8, 4, 2}); + implbmm_nchw_hmma.emplace_back(AlgoParam{128, 64, 32, 32, 32, 32, 8, 8, 4, 2}); + implbmm_nchw_hmma.emplace_back(AlgoParam{64, 64, 32, 32, 32, 32, 8, 8, 4, 2}); + for (auto&& algo : implbmm_nchw_hmma) { + all_algos.push_back(&algo); + } + } +#endif +} + ConvolutionBackwardFilterImpl::AlgoPack ConvolutionBackwardFilterImpl::sm_algo_pack; ConvolutionBackwardFilterImpl::AlgoBase::SizeArgs::SizeArgs( diff --git a/dnn/src/cuda/convolution/backward_filter/algo.h b/dnn/src/cuda/convolution/backward_filter/algo.h index 9f8cd758..f4869416 100644 --- a/dnn/src/cuda/convolution/backward_filter/algo.h +++ b/dnn/src/cuda/convolution/backward_filter/algo.h @@ -37,6 +37,8 @@ public: CUDA_CHANWISE, CUDA_BFLOAT16, CUDA_GROUP_CONV_GENERAL, + CUDA_IMPLICIT_BATCHED_GEMM_FMA_NCHW_F32, + CUDA_IMPLICIT_BATCHED_GEMM_HMMA_NCHW_F16, }; using Mapper = std::unordered_map; @@ -210,9 +212,86 @@ private: WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const; }; +class ConvolutionBackwardFilterImpl::AlgoFloat32NCHWFMAImplicitBatchedGemm final + : public AlgoBase { +public: + struct AlgoParam { + int threadblock_m; + int threadblock_n; + int threadblock_k; + int warp_m; + int warp_n; + int warp_k; + int stage; + std::string to_string() { + return ssprintf( + "_%dX%dX%d_%dX%dX%d_%dstage", threadblock_m, threadblock_n, + threadblock_k, warp_m, warp_n, warp_k, stage); + } + }; + AlgoFloat32NCHWFMAImplicitBatchedGemm(AlgoParam algo_param) + : m_algo_param{algo_param}, + m_name{ssprintf( + "FLOAT32_NCHW_FMA_IMPLICIT_BATCHED_GEMM%s", + m_algo_param.to_string().c_str())} {} + bool is_available(const SizeArgs& args) const override; + size_t get_workspace_in_bytes(const SizeArgs& args) const override { return 0; } + void exec(const ExecArgs& args) const override; + const char* name() const override { return m_name.c_str(); } + AlgoAttribute attribute() const override { 
return AlgoAttribute::REPRODUCIBLE; }
+    MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_BATCHED_GEMM_FMA_NCHW_F32)
+
+private:
+    const void* get_available_op(const SizeArgs& args) const;
+    AlgoParam m_algo_param;
+    std::string m_name;
+};
+
+class ConvolutionBackwardFilterImpl::AlgoFloat16NCHWHMMAImplicitBatchedGemm final
+        : public AlgoBase {
+public:
+    /// add the instruction shape as a member of the algo param, because the f16 tensor
+    /// core has 2 different matrix shapes (i.e. mma.884 and mma.1688)
+    struct AlgoParam {
+        int threadblock_m;
+        int threadblock_n;
+        int threadblock_k;
+        int warp_m;
+        int warp_n;
+        int warp_k;
+        int instruction_m;
+        int instruction_n;
+        int instruction_k;
+        int stage;
+        std::string to_string() {
+            return ssprintf(
+                    "_%dX%dX%d_%dX%dX%d_mma%dX%dX%d_%dstage", threadblock_m,
+                    threadblock_n, threadblock_k, warp_m, warp_n, warp_k, instruction_m,
+                    instruction_n, instruction_k, stage);
+        }
+    };
+    AlgoFloat16NCHWHMMAImplicitBatchedGemm(AlgoParam algo_param)
+            : m_algo_param{algo_param},
+              m_name{ssprintf(
+                      "FLOAT16_NCHW_HMMA_IMPLICIT_BATCHED_GEMM%s",
+                      m_algo_param.to_string().c_str())} {}
+    bool is_available(const SizeArgs& args) const override;
+    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
+    void exec(const ExecArgs& args) const override;
+    const char* name() const override { return m_name.c_str(); }
+    AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
+    MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_BATCHED_GEMM_HMMA_NCHW_F16)
+
+private:
+    const void* get_available_op(const SizeArgs& args) const;
+    AlgoParam m_algo_param;
+    std::string m_name;
+};
+
 class ConvolutionBackwardFilterImpl::AlgoPack : NonCopyableObj {
     // defined in cudnn.cpp
     void fill_cudnn_algos();
+    void fill_dwconv_algos();
     AlgoBase::Mapper m_all_algos_map;
@@ -224,6 +303,8 @@ public:
     AlgoChanwise chanwise;
     AlgoGroupConvGeneral group;
     AlgoBFloat16 bfloat16;
+    std::vector<AlgoFloat32NCHWFMAImplicitBatchedGemm> implbmm_nchw_fma;
+    std::vector<AlgoFloat16NCHWHMMAImplicitBatchedGemm> implbmm_nchw_hmma;
     std::vector
             //! all algorithms
diff --git a/dnn/src/cuda/convolution/backward_filter/implicit_batched_gemm_float16_nchw_hmma.cpp b/dnn/src/cuda/convolution/backward_filter/implicit_batched_gemm_float16_nchw_hmma.cpp
new file mode 100644
index 00000000..4289d184
--- /dev/null
+++ b/dnn/src/cuda/convolution/backward_filter/implicit_batched_gemm_float16_nchw_hmma.cpp
@@ -0,0 +1,172 @@
+/**
+ * \file
+ * dnn/src/cuda/convolution/backward_filter/implicit_batched_gemm_float16_nchw_hmma.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied. 
+ */ + +#include "src/cuda/convolution/backward_filter/algo.h" +#include "src/cuda/cutlass/singleton.h" +#include "src/cuda/utils.h" + +using namespace megdnn; +using namespace cuda; +using namespace cutlass::library; + +const void* ConvolutionBackwardFilterImpl::AlgoFloat16NCHWHMMAImplicitBatchedGemm:: + get_available_op(const SizeArgs& args) const { + auto get_alignment = [](const TensorLayout& layout) { + int alignment = 0; + int width = layout.dtype.size(layout[3]); + for (int candidate : {16, 4, 2}) { + if (width % candidate == 0) { + alignment = candidate; + break; + } + } + alignment /= layout.dtype.size(1); + return alignment; + }; + int alignment_src = get_alignment(*args.src_layout); + int alignment_diff = get_alignment(*args.diff_layout); + megdnn_assert(alignment_src >= 1 && alignment_diff >= 1); + NumericTypeID accumulator_dtype = + args.opr->param().compute_mode == param::Convolution::ComputeMode::DEFAULT + ? NumericTypeID::kF16 + : NumericTypeID::kF32; + ConvolutionKey key{ + cutlass::conv::Operator::kWgrad, + NumericTypeID::kF16, // src tensor data type + LayoutTypeID::kTensorNCHW, // src tensor layout + NumericTypeID::kF16, // diff tensor data type + LayoutTypeID::kTensorNCHW, // diff tensor layout + NumericTypeID::kF32, // grad tensor data type + LayoutTypeID::kTensorNCHW, // grad tensor layout + NumericTypeID::kF32, // dummy argument, not used. + LayoutTypeID::kTensorNCHW, // dummy argument, not used + accumulator_dtype, + cutlass::conv::ConvType::kDepthwiseConvolution, + m_algo_param.threadblock_m, + m_algo_param.threadblock_n, + m_algo_param.threadblock_k, + m_algo_param.warp_m, + m_algo_param.warp_n, + m_algo_param.warp_k, + m_algo_param.instruction_m, + m_algo_param.instruction_n, + m_algo_param.instruction_k, + cutlass::epilogue::EpilogueType::kLinearCombination, // no bias + m_algo_param.stage, + cutlass::conv::SpecialOptimizeDesc::NONE, + alignment_src, + alignment_diff, + true}; + return (void*)Singleton::get().operation_table.find_op(key); +} + +bool ConvolutionBackwardFilterImpl::AlgoFloat16NCHWHMMAImplicitBatchedGemm:: + is_available(const SizeArgs& args) const { +#define RETURN_IF_FALSE(stmt_) \ + if (!(stmt_)) \ + return false; + RETURN_IF_FALSE(is_compute_capability_required(7, 0)); + RETURN_IF_FALSE( + args.src_layout->is_contiguous() && args.diff_layout->is_contiguous() && + args.grad_layout->is_contiguous()); + using Param = param::Convolution; + using Format = Param::Format; + using Sparse = Param::Sparse; + using Mode = Param::Mode; + using ComputeMode = Param::ComputeMode; + auto&& param = args.opr->param(); + auto&& fm = args.grad_filter_meta; + RETURN_IF_FALSE(param.compute_mode == ComputeMode::FLOAT32); + RETURN_IF_FALSE( + param.format == Format::NCHW && + args.src_layout->dtype.enumv() == DTypeEnum::Float16 && + args.diff_layout->dtype.enumv() == DTypeEnum::Float16 && + args.grad_layout->dtype.enumv() == DTypeEnum::Float16); + RETURN_IF_FALSE(param.sparse == Sparse::GROUP); + RETURN_IF_FALSE(param.mode == Mode::CROSS_CORRELATION); + // check if channelwise convolution + RETURN_IF_FALSE(fm.icpg == 1 && fm.ocpg == 1); + RETURN_IF_FALSE(param.dilate_h == 1 && param.dilate_w == 1); + const auto* op = get_available_op(args); + RETURN_IF_FALSE(op != nullptr); + return true; +#undef RETURN_IF_FALSE +} + +size_t ConvolutionBackwardFilterImpl::AlgoFloat16NCHWHMMAImplicitBatchedGemm:: + get_workspace_in_bytes(const SizeArgs& args) const { + auto layout = *args.grad_layout; + // modify data type + layout.modify_dtype_inplace(dtype::Float32()); + return 
layout.span().dist_byte(); +} + +void ConvolutionBackwardFilterImpl::AlgoFloat16NCHWHMMAImplicitBatchedGemm::exec( + const ExecArgs& args) const { + auto&& param = args.opr->param(); + auto&& fm = args.grad_filter_meta; + int hi = args.src_layout->operator[](2), wi = args.src_layout->operator[](3); + int n = args.diff_layout->operator[](0), ho = args.diff_layout->operator[](2), + wo = args.diff_layout->operator[](3); + int co = fm.group, ci = co, groups = co; + int fh = fm.spatial[0], fw = fm.spatial[1]; + int sh = fm.stride[0], sw = fm.stride[1]; + int ph = fm.padding[0], pw = fm.padding[1]; + int dh = param.dilate_h, dw = param.dilate_w; + + // check if channelwise convolution + megdnn_assert(fm.icpg == 1 && fm.ocpg == 1); + auto&& stream = cuda_stream(args.opr->handle()); + + float alpha = 1.f; + float beta = 0.f; + + const Operation* op = (const Operation*)get_available_op(args); + + cutlass::conv::Conv2dProblemSize problem_size{ + n, hi, wi, ci, co, fh, fw, ho, + wo, ph, pw, sh, sw, dh, dw, cutlass::conv::Mode::kCrossCorrelation, + 1, // split k slices, always 1 + groups, // groups + }; + + cutlass::library::ConvolutionArguments conv_args{ + problem_size, + args.src_tensor->raw_ptr(), + args.diff_tensor->raw_ptr(), + nullptr, + nullptr, + args.workspace.raw_ptr, + &alpha, + &beta, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr}; + + cutlass_check(op->run(&conv_args, nullptr, stream)); + + after_kernel_launch(); + + auto&& typecvt = args.opr->handle()->create_operator(); + auto f32_grad_layout = *args.grad_layout; + // modify data type + f32_grad_layout.modify_dtype_inplace(dtype::Float32()); + TensorND src{args.workspace.raw_ptr, f32_grad_layout}, + dst{args.grad_tensor->raw_ptr(), *args.grad_layout}; + typecvt->exec(src, dst); +} + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/convolution/backward_filter/implicit_batched_gemm_float32_nchw_fma.cpp b/dnn/src/cuda/convolution/backward_filter/implicit_batched_gemm_float32_nchw_fma.cpp new file mode 100644 index 00000000..68491ac2 --- /dev/null +++ b/dnn/src/cuda/convolution/backward_filter/implicit_batched_gemm_float32_nchw_fma.cpp @@ -0,0 +1,135 @@ +/** + * \file + * dnn/src/cuda/convolution/backward_filter/implicit_batched_gemm_float32_nchw_fma.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#include "src/cuda/convolution/backward_filter/algo.h" +#include "src/cuda/cutlass/singleton.h" +#include "src/cuda/utils.h" + +using namespace megdnn; +using namespace cuda; +using namespace cutlass::library; + +const void* ConvolutionBackwardFilterImpl::AlgoFloat32NCHWFMAImplicitBatchedGemm:: + get_available_op(const SizeArgs& args) const { + ConvolutionKey key{ + cutlass::conv::Operator::kWgrad, + NumericTypeID::kF32, // src tensor data type + LayoutTypeID::kTensorNCHW, // src tensor layout + NumericTypeID::kF32, // diff tensor data type + LayoutTypeID::kTensorNCHW, // diff tensor layout + NumericTypeID::kF32, // grad tensor data type + LayoutTypeID::kTensorNCHW, // grad tensor layout + NumericTypeID::kF32, // dummy argument, not used. 
+ LayoutTypeID::kTensorNCHW, // dummy argument, not used + NumericTypeID::kF32, + cutlass::conv::ConvType::kDepthwiseConvolution, + m_algo_param.threadblock_m, + m_algo_param.threadblock_n, + m_algo_param.threadblock_k, + m_algo_param.warp_m, + m_algo_param.warp_n, + m_algo_param.warp_k, + 1, + 1, + 1, + cutlass::epilogue::EpilogueType::kLinearCombination, // no bias + m_algo_param.stage, + cutlass::conv::SpecialOptimizeDesc::NONE, + 1, + 1, + true}; + return (void*)Singleton::get().operation_table.find_op(key); +} + +bool ConvolutionBackwardFilterImpl::AlgoFloat32NCHWFMAImplicitBatchedGemm::is_available( + const SizeArgs& args) const { +#define RETURN_IF_FALSE(stmt_) \ + if (!(stmt_)) \ + return false; + RETURN_IF_FALSE(is_compute_capability_required(6, 1)); + RETURN_IF_FALSE( + args.src_layout->is_contiguous() && args.diff_layout->is_contiguous() && + args.grad_layout->is_contiguous()); + using Param = param::Convolution; + using Format = Param::Format; + using Sparse = Param::Sparse; + using Mode = Param::Mode; + auto&& param = args.opr->param(); + auto&& fm = args.grad_filter_meta; + RETURN_IF_FALSE( + param.format == Format::NCHW && + args.src_layout->dtype.enumv() == DTypeEnum::Float32 && + args.diff_layout->dtype.enumv() == DTypeEnum::Float32 && + args.grad_layout->dtype.enumv() == DTypeEnum::Float32); + RETURN_IF_FALSE(param.sparse == Sparse::GROUP); + RETURN_IF_FALSE(param.mode == Mode::CROSS_CORRELATION); + // check if channelwise convolution + RETURN_IF_FALSE(fm.icpg == 1 && fm.ocpg == 1); + RETURN_IF_FALSE(param.dilate_h == 1 && param.dilate_w == 1); + const auto* op = get_available_op(args); + RETURN_IF_FALSE(op != nullptr); + return true; +#undef RETURN_IF_FALSE +} + +void ConvolutionBackwardFilterImpl::AlgoFloat32NCHWFMAImplicitBatchedGemm::exec( + const ExecArgs& args) const { + auto&& param = args.opr->param(); + auto&& fm = args.grad_filter_meta; + int hi = args.src_layout->operator[](2), wi = args.src_layout->operator[](3); + int n = args.diff_layout->operator[](0), ho = args.diff_layout->operator[](2), + wo = args.diff_layout->operator[](3); + int co = fm.group, ci = co, groups = co; + int fh = fm.spatial[0], fw = fm.spatial[1]; + int sh = fm.stride[0], sw = fm.stride[1]; + int ph = fm.padding[0], pw = fm.padding[1]; + int dh = param.dilate_h, dw = param.dilate_w; + + // check if channelwise convolution + megdnn_assert(fm.icpg == 1 && fm.ocpg == 1); + auto&& stream = cuda_stream(args.opr->handle()); + + float alpha = 1.f; + float beta = 0.f; + + const Operation* op = (const Operation*)get_available_op(args); + + cutlass::conv::Conv2dProblemSize problem_size{ + n, hi, wi, ci, co, fh, fw, ho, + wo, ph, pw, sh, sw, dh, dw, cutlass::conv::Mode::kCrossCorrelation, + 1, // split k slices, always 1 + groups, // groups + }; + + cutlass::library::ConvolutionArguments conv_args{ + problem_size, + args.src_tensor->raw_ptr(), + args.diff_tensor->raw_ptr(), + nullptr, + nullptr, + args.grad_tensor->raw_ptr(), + &alpha, + &beta, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr}; + + cutlass_check(op->run(&conv_args, nullptr, stream)); + + after_kernel_launch(); +} + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/convolution/opr_impl.cpp b/dnn/src/cuda/convolution/opr_impl.cpp index c1d1784d..eefacf57 100644 --- a/dnn/src/cuda/convolution/opr_impl.cpp +++ b/dnn/src/cuda/convolution/opr_impl.cpp @@ -116,15 +116,18 @@ ConvolutionBackwardDataImpl::Algorithm* ConvolutionBackwardDataImpl:: AlgoBase::SizeArgs args(this, filter, diff, grad); //! 
choose for large kernel cases - size_t fh = args.filter_meta.spatial[2], fw = args.filter_meta.spatial[3]; + size_t fh = args.filter_meta.spatial[0], fw = args.filter_meta.spatial[1]; size_t ho = diff[2], wo = diff[3]; const bool prefer_dnn_lk_implbmm = args.filter_meta.format == Param::Format::NCHW && ho <= 2 * fh && wo <= 2 * fw; if (prefer_dnn_lk_implbmm) { - if (sm_algo_pack.implbmm_nchw_hmma.is_available_attribute( +#if CUDA_VERSION >= 10020 + if (sm_algo_pack.implbmm_nchw_hmma[0].is_available_attribute( args, positive_attr, negative_attr, workspace_limit_in_bytes)) return &sm_algo_pack.implbmm_nchw_hmma[0]; - if (sm_algo_pack.implbmm_nchw_fma.is_available_attribute(args, positive_attr, negative_attr, workspace_limit_in_bytes)) +#endif + if (sm_algo_pack.implbmm_nchw_fma[0].is_available_attribute( + args, positive_attr, negative_attr, workspace_limit_in_bytes)) return &sm_algo_pack.implbmm_nchw_fma[0]; } @@ -255,6 +258,23 @@ ConvolutionBackwardFilterImpl::Algorithm* ConvolutionBackwardFilterImpl:: const AlgoAttribute& negative_attr) { AlgoBase::SizeArgs args(this, src, diff, grad); + //! choose for large kernel cases + size_t fh = args.grad_filter_meta.spatial[0], fw = args.grad_filter_meta.spatial[1]; + size_t ho = diff[2], wo = diff[3]; + const bool prefer_dnn_lk_implbmm = + args.grad_filter_meta.format == Param::Format::NCHW && ho <= 2 * fh && + wo <= 2 * fw; + if (prefer_dnn_lk_implbmm) { +#if CUDA_VERSION >= 10020 + if (sm_algo_pack.implbmm_nchw_hmma[0].is_available_attribute( + args, positive_attr, negative_attr, workspace_limit_in_bytes)) + return &sm_algo_pack.implbmm_nchw_hmma[0]; +#endif + if (sm_algo_pack.implbmm_nchw_fma[0].is_available_attribute( + args, positive_attr, negative_attr, workspace_limit_in_bytes)) + return &sm_algo_pack.implbmm_nchw_fma[0]; + } + if (args.grad_filter_meta.group > 1 && sm_algo_pack.chanwise.is_available_attribute( args, positive_attr, negative_attr, workspace_limit_in_bytes)) { diff --git a/dnn/src/cuda/convolution/opr_impl.h b/dnn/src/cuda/convolution/opr_impl.h index 0a3c22b7..4763a1af 100644 --- a/dnn/src/cuda/convolution/opr_impl.h +++ b/dnn/src/cuda/convolution/opr_impl.h @@ -156,6 +156,8 @@ public: class AlgoChanwise; class AlgoGroupConvGeneral; class AlgoBFloat16; + class AlgoFloat32NCHWFMAImplicitBatchedGemm; + class AlgoFloat16NCHWHMMAImplicitBatchedGemm; class AlgoPack; diff --git a/dnn/src/cuda/cutlass/convolution_operation.h b/dnn/src/cuda/cutlass/convolution_operation.h index b46b1cc5..0004bd93 100644 --- a/dnn/src/cuda/cutlass/convolution_operation.h +++ b/dnn/src/cuda/cutlass/convolution_operation.h @@ -136,6 +136,15 @@ template struct init_epilogue_param_; template +struct init_epilogue_param_ { + using ElementCompute = typename EpilogueOp::ElementCompute; + typename EpilogueOp::Params get(ConvolutionArguments const* conv_args) { + return {*static_cast(conv_args->alpha), + *static_cast(conv_args->beta)}; + } +}; + +template struct init_epilogue_param_< EpilogueOp, epilogue::EpilogueType::kBiasAddLinearCombination> { using ElementCompute = typename EpilogueOp::ElementCompute; @@ -290,6 +299,159 @@ public: /////////////////////////////////////////////////////////////////////////////////////////////////// +/// We add a new template class to handle convolution backward filter operation, because +/// the device-level convolution operator of backward filter is different from the +/// others (convolution forward and convolution backward data). +/// But the description object is reused in this wrapper of convolution backward filter. 
+/// The reason is that we do not want to introduce another unnecessary structure.
+/// TODO: Maybe the device-level operators in cutlass for convolution forward, backward
+/// data and backward filter should be combined.
+template <typename Operator_>
+class ConvolutionBackwardFilterOperationBase : public Operation {
+public:
+    using Operator = Operator_;
+    using ElementSrc = typename Operator::ElementSrc;
+    using LayoutSrc = typename Operator::LayoutSrc;
+    using ElementDiff = typename Operator::ElementDiff;
+    using LayoutDiff = typename Operator::LayoutDiff;
+    using ElementGrad = typename Operator::ElementGrad;
+    using LayoutGrad = typename Operator::LayoutGrad;
+    using ElementAccumulator = typename Operator::ElementAccumulator;
+
+    ConvolutionBackwardFilterOperationBase(char const* name = "unknown_convolution") {
+        m_description.name = name;
+        m_description.provider = Provider::kCUTLASS;
+        m_description.kind = OperationKind::kConvolution;
+        m_description.conv_op = Operator::kConvolutionalOperator;
+
+        m_description.tile_description.threadblock_shape = make_Coord(
+                Operator::ThreadblockShape::kM, Operator::ThreadblockShape::kN,
+                Operator::ThreadblockShape::kK);
+
+        m_description.tile_description.threadblock_stages = Operator::kStages;
+
+        m_description.tile_description.warp_count = make_Coord(
+                Operator::ConvolutionKernel::WarpCount::kM,
+                Operator::ConvolutionKernel::WarpCount::kN,
+                Operator::ConvolutionKernel::WarpCount::kK);
+
+        m_description.tile_description.math_instruction.instruction_shape = make_Coord(
+                Operator::InstructionShape::kM, Operator::InstructionShape::kN,
+                Operator::InstructionShape::kK);
+
+        m_description.tile_description.math_instruction.element_accumulator =
+                NumericTypeMap<ElementAccumulator>::kId;
+
+        m_description.tile_description.math_instruction.opcode_class =
+                OpcodeClassMap<typename Operator::OperatorClass>::kId;
+
+        m_description.tile_description.math_instruction.math_operation =
+                MathOperationMap<typename Operator::MathOperator>::kId;
+
+        m_description.tile_description.minimum_compute_capability =
+                ArchMap<typename Operator::ArchTag, typename Operator::OperatorClass>::kMin;
+
+        m_description.tile_description.maximum_compute_capability =
+                ArchMap<typename Operator::ArchTag, typename Operator::OperatorClass>::kMax;
+
+        /// src in description -> src in C++ template
+        m_description.src =
+                make_TensorDescription<ElementSrc, LayoutSrc>(Operator::kAlignmentSrc);
+        /// filter in description -> diff in C++ template
+        m_description.filter = make_TensorDescription<ElementDiff, LayoutDiff>(
+                Operator::kAlignmentDiff);
+        /// dst in description -> grad in C++ template
+        m_description.dst = make_TensorDescription<ElementGrad, LayoutGrad>(
+                Operator::kAlignmentGrad);
+        /// because the bias tensor is not used in the ConvolutionBackwardFilter
+        /// operation, the following tensor description is a dummy argument
+        m_description.bias = make_TensorDescription<ElementGrad, LayoutGrad>(
+                Operator::kAlignmentGrad);
+
+        m_description.convolution_type = Operator::kConvolutionType;
+        m_description.arch_tag = ArchTagMap<typename Operator::ArchTag>::kId;
+
+        m_description.epilogue_type = Operator::EpilogueOutputOp::kType;
+        m_description.epilogue_count = Operator::EpilogueOutputOp::kCount;
+
+        m_description.threadblock_swizzle =
+                ThreadblockSwizzleMap<typename Operator::ThreadblockSwizzle>::kId;
+
+        m_description.special_optimization = Operator::kSpecialOpt;
+        m_description.gemm_mode = Operator::kGemmMode;
+        /// the ConvolutionBackwardFilter operation is only used for depthwise
+        /// convolution, so the option without_shared_load is always true
+        m_description.without_shared_load = true;
+    }
+
+    virtual OperationDescription const& description() const { return m_description; }
+
+protected:
+    ConvolutionDescription m_description;
+};
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <typename Operator_>
+class 
ConvolutionBackwardFilterOperation + : public ConvolutionBackwardFilterOperationBase { +public: + using Operator = Operator_; + using ElementSrc = typename Operator::ElementSrc; + using LayoutSrc = typename Operator::LayoutSrc; + using ElementDiff = typename Operator::ElementDiff; + using LayoutDiff = typename Operator::LayoutDiff; + using ElementGrad = typename Operator::ElementGrad; + using LayoutGrad = typename Operator::LayoutGrad; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + + using OperatorArguments = typename Operator::Arguments; + + ConvolutionBackwardFilterOperation(char const* name = "unknown_gemm") + : ConvolutionBackwardFilterOperationBase(name) {} + + virtual Status run( + void const* arguments_ptr, void* device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + cutlass::conv::Operator conv_op = this->m_description.conv_op; + ConvolutionArguments const* conv_args = + reinterpret_cast(arguments_ptr); + const auto& ps = conv_args->problem_size; + + OperatorArguments args; + args.problem_size = ps; + /// src in convolution arguments -> ref_src + args.ref_src = { + static_cast(const_cast(conv_args->src)), + LayoutSrc::packed(implicit_gemm_tensor_b_extent(conv_op, ps))}; + /// filter in convolution arguments -> ref_diff + args.ref_diff = { + static_cast(const_cast(conv_args->filter)), + LayoutDiff::packed(implicit_gemm_tensor_a_extent(conv_op, ps))}; + /// dst in convolution arguments -> ref_grad + args.ref_grad = { + static_cast(conv_args->dst), + LayoutGrad::packed(implicit_gemm_tensor_c_extent(conv_op, ps))}; + + args.output_op = init_epilogue_param().get( + conv_args); + + Operator op; + Status status = op.initialize(args, device_workspace); + + if (status != Status::kSuccess) { + return status; + } + + return op.run(stream); + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + } // namespace library } // namespace cutlass diff --git a/dnn/src/cuda/cutlass/initialize_all.cu b/dnn/src/cuda/cutlass/initialize_all.cu index 44d6faf3..e51c5ec5 100644 --- a/dnn/src/cuda/cutlass/initialize_all.cu +++ b/dnn/src/cuda/cutlass/initialize_all.cu @@ -45,6 +45,11 @@ namespace library { ///////////////////////////////////////////////////////////////////////////////////////////////// #if ((__CUDACC_VER_MAJOR__ > 10) || \ + (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 1)) +#define CUTLASS_ARCH_MMA_SM70_SUPPORTED 1 +#endif + +#if ((__CUDACC_VER_MAJOR__ > 10) || \ (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2)) #define CUTLASS_ARCH_MMA_SM75_SUPPORTED 1 #endif @@ -56,14 +61,18 @@ void initialize_all_conv2d_simt_operations(Manifest& manifest); void initialize_all_deconv_simt_operations(Manifest& manifest); void initialize_all_dwconv2d_fprop_simt_operations(Manifest& manifest); void initialize_all_dwconv2d_dgrad_simt_operations(Manifest& manifest); -#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) && CUTLASS_ARCH_MMA_SM75_SUPPORTED +void initialize_all_dwconv2d_wgrad_simt_operations(Manifest& manifest); +#if defined(CUTLASS_ARCH_MMA_SM70_SUPPORTED) && CUTLASS_ARCH_MMA_SM70_SUPPORTED void initialize_all_gemm_tensorop884_operations(Manifest& manifest); +void initialize_all_dwconv2d_fprop_tensorop884_operations(Manifest& manifest); +void initialize_all_dwconv2d_dgrad_tensorop884_operations(Manifest& manifest); +void initialize_all_dwconv2d_wgrad_tensorop884_operations(Manifest& manifest); +#endif +#if 
defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) && CUTLASS_ARCH_MMA_SM75_SUPPORTED void initialize_all_gemm_tensorop1688_operations(Manifest& manifest); void initialize_all_conv2d_tensorop8816_operations(Manifest& manifest); void initialize_all_conv2d_tensorop8832_operations(Manifest& manifest); void initialize_all_deconv_tensorop8816_operations(Manifest& manifest); -void initialize_all_dwconv2d_fprop_tensorop884_operations(Manifest& manifest); -void initialize_all_dwconv2d_dgrad_tensorop884_operations(Manifest& manifest); #endif void initialize_all(Manifest& manifest) { @@ -72,14 +81,18 @@ void initialize_all(Manifest& manifest) { initialize_all_deconv_simt_operations(manifest); initialize_all_dwconv2d_fprop_simt_operations(manifest); initialize_all_dwconv2d_dgrad_simt_operations(manifest); -#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) && CUTLASS_ARCH_MMA_SM75_SUPPORTED + initialize_all_dwconv2d_wgrad_simt_operations(manifest); +#if defined(CUTLASS_ARCH_MMA_SM70_SUPPORTED) && CUTLASS_ARCH_MMA_SM70_SUPPORTED initialize_all_gemm_tensorop884_operations(manifest); + initialize_all_dwconv2d_fprop_tensorop884_operations(manifest); + initialize_all_dwconv2d_dgrad_tensorop884_operations(manifest); + initialize_all_dwconv2d_wgrad_tensorop884_operations(manifest); +#endif +#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) && CUTLASS_ARCH_MMA_SM75_SUPPORTED initialize_all_gemm_tensorop1688_operations(manifest); initialize_all_conv2d_tensorop8816_operations(manifest); initialize_all_conv2d_tensorop8832_operations(manifest); initialize_all_deconv_tensorop8816_operations(manifest); - initialize_all_dwconv2d_fprop_tensorop884_operations(manifest); - initialize_all_dwconv2d_dgrad_tensorop884_operations(manifest); #endif } diff --git a/dnn/src/cuda/cutlass/operation_table.h b/dnn/src/cuda/cutlass/operation_table.h index 1d4f3bcb..baa9e1ed 100644 --- a/dnn/src/cuda/cutlass/operation_table.h +++ b/dnn/src/cuda/cutlass/operation_table.h @@ -280,7 +280,6 @@ struct ConvolutionKeyHasher { inline size_t operator()(ConvolutionKey const& key) const { return Hash() .update(&key.conv_op, sizeof(key.conv_op)) - .update(&key.conv_op, sizeof(key.conv_op)) .update(&key.element_src, sizeof(key.element_src)) .update(&key.layout_src, sizeof(key.layout_src)) .update(&key.element_filter, sizeof(key.element_filter)) diff --git a/dnn/src/cuda/cutlass/util.cu b/dnn/src/cuda/cutlass/util.cu index a309a66a..49fd26cb 100644 --- a/dnn/src/cuda/cutlass/util.cu +++ b/dnn/src/cuda/cutlass/util.cu @@ -1322,6 +1322,8 @@ static struct { {"batch_convolution", "BatchConvolution", conv::ConvType::kBatchConvolution}, {"local", "Local", conv::ConvType::kLocal}, {"local_share", "LocalShare", conv::ConvType::kLocalShare}, + {"depthwise_convolution", "DepthwiseConvolution", + conv::ConvType::kDepthwiseConvolution}, }; /// Converts a ConvType enumerant to a string diff --git a/dnn/src/cuda/matrix_mul/algos.cpp b/dnn/src/cuda/matrix_mul/algos.cpp index 1aaf8609..3a12ffe8 100644 --- a/dnn/src/cuda/matrix_mul/algos.cpp +++ b/dnn/src/cuda/matrix_mul/algos.cpp @@ -44,7 +44,7 @@ MatrixMulForwardImpl::AlgoPack::AlgoPack() { for (auto&& algo : simt_float32_gemv_batched_strided) { all_algos.push_back(&algo); } -#if CUDA_VERSION >= 10020 +#if CUDA_VERSION >= 10010 for (auto&& algo : tensorop_float16) { all_algos.push_back(&algo); } @@ -113,21 +113,26 @@ void MatrixMulForwardImpl::AlgoPack::fill_cutlass_algos() { simt_float32_gemv_batched_strided.emplace_back(128); simt_float32_gemv_batched_strided.emplace_back(64); simt_float32_gemv_batched_strided.emplace_back(32); 
diff --git a/dnn/src/cuda/matrix_mul/algos.cpp b/dnn/src/cuda/matrix_mul/algos.cpp
index 1aaf8609..3a12ffe8 100644
--- a/dnn/src/cuda/matrix_mul/algos.cpp
+++ b/dnn/src/cuda/matrix_mul/algos.cpp
@@ -44,7 +44,7 @@ MatrixMulForwardImpl::AlgoPack::AlgoPack() {
     for (auto&& algo : simt_float32_gemv_batched_strided) {
         all_algos.push_back(&algo);
     }
-#if CUDA_VERSION >= 10020
+#if CUDA_VERSION >= 10010
     for (auto&& algo : tensorop_float16) {
         all_algos.push_back(&algo);
     }
@@ -113,21 +113,26 @@ void MatrixMulForwardImpl::AlgoPack::fill_cutlass_algos() {
     simt_float32_gemv_batched_strided.emplace_back(128);
     simt_float32_gemv_batched_strided.emplace_back(64);
     simt_float32_gemv_batched_strided.emplace_back(32);
-#define FOREACH_CUTLASS_MATMUL_F16_SHAPES(cb) \
-    cb(256, 128, 32, 64, 64, 32, 8, 8, 4);    \
-    cb(128, 256, 32, 64, 64, 32, 8, 8, 4);    \
-    cb(128, 128, 32, 64, 64, 32, 8, 8, 4);    \
-    cb(256, 128, 32, 64, 64, 32, 16, 8, 8);   \
-    cb(128, 256, 32, 64, 64, 32, 16, 8, 8);   \
+#define FOREACH_CUTLASS_MATMUL_MMA_SM70_SHAPES(cb) \
+    cb(256, 128, 32, 64, 64, 32, 8, 8, 4);         \
+    cb(128, 256, 32, 64, 64, 32, 8, 8, 4);         \
+    cb(128, 128, 32, 64, 64, 32, 8, 8, 4);
+#define FOREACH_CUTLASS_MATMUL_MMA_SM75_SHAPES(cb) \
+    cb(256, 128, 32, 64, 64, 32, 16, 8, 8);        \
+    cb(128, 256, 32, 64, 64, 32, 16, 8, 8);        \
     cb(128, 128, 32, 64, 64, 32, 16, 8, 8);
 
 #define cb(...)                                              \
     tensorop_float16.emplace_back(AlgoParam{__VA_ARGS__});   \
     tensorop_float16_split_k.emplace_back(AlgoParam{__VA_ARGS__});
+#if CUDA_VERSION >= 10010
+    FOREACH_CUTLASS_MATMUL_MMA_SM70_SHAPES(cb)
+#endif
 #if CUDA_VERSION >= 10020
-    FOREACH_CUTLASS_MATMUL_F16_SHAPES(cb)
+    FOREACH_CUTLASS_MATMUL_MMA_SM75_SHAPES(cb)
 #endif
 #undef cb
-#undef FOREACH_CUTLASS_MATMUL_F16_SHAPES
+#undef FOREACH_CUTLASS_MATMUL_MMA_SM70_SHAPES
+#undef FOREACH_CUTLASS_MATMUL_MMA_SM75_SHAPES
 }
 #endif
diff --git a/dnn/src/cuda/matrix_mul/algos.h b/dnn/src/cuda/matrix_mul/algos.h
index 7b76b3a9..c2eaf7fa 100644
--- a/dnn/src/cuda/matrix_mul/algos.h
+++ b/dnn/src/cuda/matrix_mul/algos.h
@@ -350,7 +350,7 @@ private:
     std::string m_name;
 };
 
-#if CUDA_VERSION >= 10020
+#if CUDA_VERSION >= 10010
 class MatrixMulForwardImpl::AlgoFloat16TensorOp final
         : public AlgoCutlassMatrixMulBase {
 public:
@@ -418,7 +418,7 @@ public:
     std::vector<AlgoFloat32SIMT> simt_float32;
     std::vector<AlgoFloat32SIMTSplitK> simt_float32_split_k;
     std::vector<AlgoFloat32SIMTGemvBatchedStrided> simt_float32_gemv_batched_strided;
-#if CUDA_VERSION >= 10020
+#if CUDA_VERSION >= 10010
     std::vector<AlgoFloat16TensorOp> tensorop_float16;
     std::vector<AlgoFloat16TensorOpSplitK> tensorop_float16_split_k;
 #endif
diff --git a/dnn/src/cuda/matrix_mul/cutlass_float16_tensorop.cpp b/dnn/src/cuda/matrix_mul/cutlass_float16_tensorop.cpp
index 9afc2d08..fe479999 100644
--- a/dnn/src/cuda/matrix_mul/cutlass_float16_tensorop.cpp
+++ b/dnn/src/cuda/matrix_mul/cutlass_float16_tensorop.cpp
@@ -15,7 +15,7 @@
 #include "src/cuda/matrix_mul/algos.h"
 #include "src/cuda/utils.h"
 
-#if CUDA_VERSION >= 10020
+#if CUDA_VERSION >= 10010
 
 using namespace megdnn;
 using namespace cuda;
diff --git a/dnn/src/cuda/matrix_mul/cutlass_float16_tensorop_split_k.cpp b/dnn/src/cuda/matrix_mul/cutlass_float16_tensorop_split_k.cpp
index 61050696..a3ef56fb 100644
--- a/dnn/src/cuda/matrix_mul/cutlass_float16_tensorop_split_k.cpp
+++ b/dnn/src/cuda/matrix_mul/cutlass_float16_tensorop_split_k.cpp
@@ -15,7 +15,7 @@
 #include "src/cuda/matrix_mul/algos.h"
 #include "src/cuda/utils.h"
 
-#if CUDA_VERSION >= 10020
+#if CUDA_VERSION >= 10010
 
 using namespace megdnn;
 using namespace cuda;
diff --git a/dnn/src/cuda/matrix_mul/opr_impl.h b/dnn/src/cuda/matrix_mul/opr_impl.h
index 69999fe5..38529a9f 100644
--- a/dnn/src/cuda/matrix_mul/opr_impl.h
+++ b/dnn/src/cuda/matrix_mul/opr_impl.h
@@ -46,9 +46,11 @@ public:
     class AlgoFloat32SIMT;
     class AlgoFloat32SIMTSplitK;
     class AlgoFloat32SIMTGemvBatchedStrided;
+#if CUDA_VERSION >= 10010
     class AlgoFloat16TensorOp;
     class AlgoFloat16TensorOpSplitK;
 #endif
+#endif
     class AlgoPack;
 
     static const AlgoPack& algo_pack() { return sm_algo_pack; }
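For reference, the shape lists above are plain X-macros: the tile shapes live in one FOREACH_* macro and a caller-supplied cb decides what each tuple expands into (here, AlgoParam registrations for the plain and split-K f16 tensor-op matmul algos). The standalone demo below is an editorial sketch of the same pattern, not MegEngine code; the printed labels are illustrative.

#include <cstdio>

// One list of (threadblock m/n/k, warp m/n/k, instruction m/n/k) shapes.
#define FOREACH_SHAPE(cb)                   \
    cb(256, 128, 32, 64, 64, 32, 8, 8, 4);  \
    cb(128, 256, 32, 64, 64, 32, 8, 8, 4);

// The callback decides what each tuple becomes; here it just prints it.
#define cb(tbm, tbn, tbk, wm, wn, wk, mm, mn, mk)                        \
    std::printf("threadblock %dx%dx%d warp %dx%dx%d mma %dx%dx%d\n",     \
                tbm, tbn, tbk, wm, wn, wk, mm, mn, mk)

int main() {
    FOREACH_SHAPE(cb);
    return 0;
}

#undef cb
#undef FOREACH_SHAPE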
diff --git a/dnn/test/cuda/chanwise_convolution.cpp b/dnn/test/cuda/chanwise_convolution.cpp
index d154a843..edbc582d 100644
--- a/dnn/test/cuda/chanwise_convolution.cpp
+++ b/dnn/test/cuda/chanwise_convolution.cpp
@@ -494,6 +494,21 @@ void check_chanwise(DType io_type, DType comp_type, Handle* handle, const char*
         checker.set_param(gconv_param({M, 7, 7, 2, 2}, io16xc32))
                 .execs({{2, 1, 1, 15, 15}, {8, 2, 7, 7}, {8, 2, 14, 14}});
     } else if (std::is_same<Op, ConvolutionBackwardFilter>::value) {
+        // align 8
+        checker.set_param(gconv_param({M, 7, 7, 1, 1}, io16xc32))
+                .execs({{8, 2, 16, 16}, {8, 2, 16, 16}, {2, 1, 1, 15, 15}});
+        // align 1
+        checker.set_param(gconv_param({M, 7, 7, 1, 1}, io16xc32))
+                .execs({{8, 2, 15, 15}, {8, 2, 15, 15}, {2, 1, 1, 15, 15}});
+        // align 2
+        checker.set_param(gconv_param({M, 7, 7, 1, 1}, io16xc32))
+                .execs({{8, 2, 14, 14}, {8, 2, 14, 14}, {2, 1, 1, 15, 15}});
+        // custom padding
+        checker.set_param(gconv_param({M, 3, 3, 1, 1}, io16xc32))
+                .execs({{8, 2, 16, 16}, {8, 2, 8, 8}, {2, 1, 1, 15, 15}});
+        // custom stride
+        checker.set_param(gconv_param({M, 7, 7, 2, 2}, io16xc32))
+                .execs({{8, 2, 14, 14}, {8, 2, 7, 7}, {2, 1, 1, 15, 15}});
     }
 }
 }  // namespace
@@ -535,14 +550,32 @@ MEGDNN_FOREACH_CUTLASS_CHANWISE_CONV_FMA_KERNEL(cb)
 
 #undef cb
 
+#define cb(tag, tbm, tbn, tbk, wm, wn, wk)                                 \
+    TEST_F(CUDA, CHANWISE_CONVOLUTION_BACKWARD_FILTER_CUTLASS_FMA_##tag) { \
+        require_compute_capability(6, 1);                                  \
+        check_chanwise<ConvolutionBackwardFilter>(                         \
+                dtype::Float32(), dtype::Float32(), handle_cuda(),         \
+                "FLOAT32_NCHW_FMA_IMPLICIT_BATCHED_GEMM_" #tbm "X" #tbn "X" #tbk \
+                "_" #wm "X" #wn "X" #wk "_2stage");                        \
+    }
+
+MEGDNN_FOREACH_CUTLASS_CHANWISE_CONV_FMA_KERNEL(cb)
+
+#undef cb
+
 #undef MEGDNN_FOREACH_CUTLASS_CHANWISE_CONV_FMA_KERNEL
 
+#if CUDA_VERSION >= 10010
 #define MEGDNN_FOREACH_CUTLASS_CHANWISE_CONV_HMMA_KERNEL(cb) \
     cb(1, 128, 128, 32, 32, 32, 32);                         \
     cb(2, 128, 256, 32, 64, 64, 32);                         \
     cb(3, 128, 64, 32, 32, 32, 32);                          \
     cb(4, 64, 128, 32, 32, 32, 32);                          \
     cb(5, 64, 64, 32, 32, 32, 32);
+#else
+// hmma instructions need cuda version >= 10.1, disable hmma testcases in this path
+#define MEGDNN_FOREACH_CUTLASS_CHANWISE_CONV_HMMA_KERNEL(cb)
+#endif
 
 // check both ioc16 and io16xc32
 #define cb(tag, tbm, tbn, tbk, wm, wn, wk)                   \
@@ -579,6 +612,19 @@ MEGDNN_FOREACH_CUTLASS_CHANWISE_CONV_HMMA_KERNEL(cb)
 
 #undef cb
 
+#define cb(tag, tbm, tbn, tbk, wm, wn, wk)                                  \
+    TEST_F(CUDA, CHANWISE_CONVOLUTION_BACKWARD_FILTER_CUTLASS_HMMA_##tag) { \
+        require_compute_capability(7, 0);                                   \
+        check_chanwise<ConvolutionBackwardFilter>(                          \
+                dtype::Float16(), dtype::Float32(), handle_cuda(),          \
+                "FLOAT16_NCHW_HMMA_IMPLICIT_BATCHED_GEMM_" #tbm "X" #tbn "X" #tbk \
+                "_" #wm "X" #wn "X" #wk "_mma8X8X4_2stage");                \
+    }
+
+MEGDNN_FOREACH_CUTLASS_CHANWISE_CONV_HMMA_KERNEL(cb)
+
+#undef cb
+
 #undef MEGDNN_FOREACH_CUTLASS_CHANWISE_CONV_FWD_HMMA_KERNEL
 
 #if MEGDNN_WITH_BENCHMARK
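For reference, the GB/s figures printed by the benchmark in the next hunk are effective-bandwidth numbers: the element counts of src, diff and the filter gradient are summed, scaled by the element width (4 bytes for fp32, 2 for fp16), converted to GiB and divided by the measured time. The standalone sketch below is editorial only; the helper name and the sample sizes are invented for illustration and are not part of the patch.

#include <cstddef>
#include <cstdio>

// Effective bandwidth in GiB/s given a total element count, the element
// width in bytes and a measured time in milliseconds.
double effective_gbps(size_t nr_elems, size_t bytes_per_elem, double time_in_ms) {
    double gib = double(nr_elems) * bytes_per_elem / (1024.0 * 1024.0 * 1024.0);
    return gib / (time_in_ms / 1e3);
}

int main() {
    // e.g. batch 64, 384 channels, 32x32 feature map, 31x31 depthwise filter, fp32
    size_t src = size_t(64) * 384 * 32 * 32;
    size_t diff = src;                 // stride 1, pad f/2 keeps the spatial size
    size_t grad = size_t(384) * 31 * 31;
    std::printf("%.2f GB/s at 1.5 ms\n", effective_gbps(src + diff + grad, 4, 1.5));
    return 0;
}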
@@ -1434,6 +1480,77 @@ TEST_F(CUDA, BENCHMARK_CHANWISE_CONV_BACKWARD_DATA_LARGE_KERNEL) {
     }
     // clang-format on
 }
+
+TEST_F(CUDA, BENCHMARK_CHANWISE_CONV_BACKWARD_FILTER_LARGE_KERNEL) {
+    CUBenchmarker<ConvolutionBackwardFilter> bencher(handle_cuda());
+    size_t RUNS = 100;
+    bencher.set_display(false).set_times(RUNS);
+    std::unique_ptr<OprProxy<ConvolutionBackwardFilter>> proxy{
+            new OprProxy<ConvolutionBackwardFilter>{true}};
+    bencher.set_proxy(proxy);
+
+    Convolution::Param param;
+    param.format = ConvBias::Param::Format::NCHW;
+    param.sparse = Convolution::Param::Sparse::GROUP;
+    NormalRNG rng;
+
+    auto run = [&](size_t batch, size_t c, size_t ih, size_t iw, size_t f, size_t s) {
+        param.pad_h = f / 2;
+        param.pad_w = f / 2;
+        param.stride_h = s;
+        param.stride_w = s;
+        param.compute_mode = param::Convolution::ComputeMode::DEFAULT;
+
+        TensorShape src = {batch, c, ih, iw}, filter = {c, 1, 1, f, f};
+
+        TensorLayout dst_layout;
+        auto opr = handle_cuda()->create_operator<Convolution>();
+        opr->param() = param;
+        opr->deduce_layout(
+                {src, dtype::Float32()}, {filter, dtype::Float32()}, dst_layout);
+        float bandwidth = static_cast<float>(
+                                  src.total_nr_elems() + filter.total_nr_elems() +
+                                  dst_layout.total_nr_elems()) /
+                          (1024 * 1024 * 1024) * 1e3;
+
+        bencher.set_param(param)
+                .set_dtype(0, dtype::Float32())
+                .set_dtype(1, dtype::Float32())
+                .set_dtype(2, dtype::Float32())
+                .set_rng(0, &rng)
+                .set_rng(1, &rng);
+        bencher.proxy()->target_execution_policy = {};
+        auto time_in_ms_fp32 = bencher.execs({src, src, filter}) / RUNS;
+
+        bencher.set_param(param)
+                .set_dtype(0, dtype::Float16())
+                .set_dtype(1, dtype::Float16())
+                .set_dtype(2, dtype::Float16())
+                .set_rng(0, &rng)
+                .set_rng(1, &rng);
+        bencher.proxy()->target_execution_policy = {};
+        param.compute_mode = param::Convolution::ComputeMode::FLOAT32;
+        bencher.set_param(param);
+        auto time_in_ms_pseudo_fp16 = bencher.execs({src, src, filter}) / RUNS;
+
+        printf("stride=%zu src=%s, filter=%s, float32: %.2fms %.2fGB/s "
+               "pseudo float16: %.2fms %.2fGB/s "
+               "speedup: "
+               "%0.2f (fp16/fp32) \n",
+               s, src.to_string().c_str(), filter.to_string().c_str(), time_in_ms_fp32,
+               bandwidth * 4 / time_in_ms_fp32, time_in_ms_pseudo_fp16,
+               bandwidth * 2 / time_in_ms_pseudo_fp16,
+               time_in_ms_fp32 / time_in_ms_pseudo_fp16);
+    };
+
+    // clang-format off
+    for (size_t b : {32, 64})
+    for (size_t f : {3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31}) {
+        run(b, 384, 32, 32, f, 1);
+        run(b, 384, 64, 64, f, 1);
+    }
+    // clang-format on
+}
 #endif
 
 // vim: syntax=cpp.doxygen
diff --git a/dnn/test/cuda/conv_bias.cpp b/dnn/test/cuda/conv_bias.cpp
index f08441e8..d4eff9ad 100644
--- a/dnn/test/cuda/conv_bias.cpp
+++ b/dnn/test/cuda/conv_bias.cpp
@@ -1093,8 +1093,11 @@ TEST_F(CUDA, CONV_BIAS_FORWARD_GROUP) {
         run(2, 32, 7, 7, 3, 3, 64, 1, 1, 1, 1, 1, 1, 4, nlmode);
         // strided case
         run(2, 32, 7, 7, 3, 3, 64, 0, 0, 2, 2, 1, 1, 8, nlmode);
+        // dilated conv is supported in CUDNN since version 7.5.0
+#if CUDNN_VERSION >= 7500
         // dilated case
         run(2, 32, 7, 7, 3, 3, 64, 0, 0, 1, 1, 2, 2, 8, nlmode);
+#endif
     }
 }
 
diff --git a/dnn/test/cuda/cutlass_matmul.cpp b/dnn/test/cuda/cutlass_matmul.cpp
index a44a3b62..72b5b9f4 100644
--- a/dnn/test/cuda/cutlass_matmul.cpp
+++ b/dnn/test/cuda/cutlass_matmul.cpp
@@ -213,7 +213,7 @@ std::vector<BenchArgs> get_feat_model_args() {
     return args;
 }
 
-#if CUDA_VERSION >= 10020
+#if CUDA_VERSION >= 10010
 std::vector<BenchArgs> get_f16_feat_model_args() {
     std::vector<BenchArgs> args;
     args.emplace_back(BenchArgs{128, 9216, 9216});
@@ -367,7 +367,7 @@ MEGDNN_FOREACH_CUTLASS_KERNEL(cb)
 #undef cb
 
 #undef MEGDNN_FOREACH_CUTLASS_KERNEL
 
-#if CUDA_VERSION >= 10020
+#if CUDA_VERSION >= 10010
 #define MEGDNN_FOREACH_CUTLASS_KERNEL(cb) \
     cb(1, 256, 128, 32, 64, 64, 32, 8, 8, 4); \
     cb(2, 128, 256, 32, 64, 64, 32, 8, 8, 4); \
@@ -403,7 +403,9 @@ MEGDNN_FOREACH_CUTLASS_KERNEL(cb)
 #undef cb
 
 #undef MEGDNN_FOREACH_CUTLASS_KERNEL
+#endif
 
+#if CUDA_VERSION >= 10020
 #define MEGDNN_FOREACH_CUTLASS_KERNEL(cb) \
     cb(1, 256, 128, 32, 64, 64, 32, 16, 8, 8); \
     cb(2, 128, 256, 32, 64, 64, 32, 16, 8, 8); \
@@ -454,7 +456,7 @@ TEST_F(CUDA, BENCHMARK_CUTLASS_MATMUL_FEAT) {
             dtype::Float32(), "CUTLASS_FLOAT32_SIMT");
 }
 
-#if CUDA_VERSION >= 10020
+#if CUDA_VERSION >= 10010
 TEST_F(CUDA, BENCHMARK_CUTLASS_F16_MATMUL_FEAT) {
     benchmark_matrix_mul(
             handle_cuda(), get_f16_feat_model_args(), dtype::Float16(),