From 96050073a2cb703dd18ae753c60d11aa1c48cc1e Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 7 Feb 2022 18:54:21 +0800 Subject: [PATCH] feat(dnn/cuda): add implicit bmm large kernel dwconv2d fprop impl GitOrigin-RevId: feb09ebb5836d26433c4a82940bb5f22795da381 --- dnn/scripts/cutlass_generator/conv2d_operation.py | 1060 +++++---- dnn/scripts/cutlass_generator/gemm_operation.py | 1729 ++++++++------ dnn/scripts/cutlass_generator/generator.py | 2404 +++++++++++++------- dnn/scripts/cutlass_generator/library.py | 893 ++++---- dnn/scripts/cutlass_generator/manifest.py | 546 +++-- dnn/src/CMakeLists.txt | 2 + dnn/src/cuda/conv_bias/algo.cpp | 27 + dnn/src/cuda/conv_bias/algo.h | 68 +- .../cuda/conv_bias/cutlass_convolution_base.cpp | 152 +- .../implicit_batched_gemm_float16_nchw_hmma.cpp | 95 + .../implicit_batched_gemm_float32_nchw_fma.cpp | 95 + dnn/src/cuda/conv_bias/opr_impl.h | 3 + .../implicit_gemm_int8_nchw4_dp4a.cpp | 3 + .../backward_data/implicit_gemm_int8_nchw_dp4a.cpp | 3 + .../backward_data/implicit_gemm_int8_nhwc_imma.cpp | 3 + dnn/src/cuda/cutlass/initialize_all.cu | 4 + dnn/src/cuda/cutlass/library.h | 3 + dnn/src/cuda/cutlass/library_internal.h | 21 + dnn/src/cuda/cutlass/operation_table.cpp | 4 + dnn/src/cuda/cutlass/operation_table.h | 14 + dnn/test/cuda/chanwise_convolution.cpp | 203 +- 21 files changed, 4706 insertions(+), 2626 deletions(-) create mode 100644 dnn/src/cuda/conv_bias/implicit_batched_gemm_float16_nchw_hmma.cpp create mode 100644 dnn/src/cuda/conv_bias/implicit_batched_gemm_float32_nchw_fma.cpp diff --git a/dnn/scripts/cutlass_generator/conv2d_operation.py b/dnn/scripts/cutlass_generator/conv2d_operation.py index 9e267fcb..5bb0bd8f 100644 --- a/dnn/scripts/cutlass_generator/conv2d_operation.py +++ b/dnn/scripts/cutlass_generator/conv2d_operation.py @@ -16,136 +16,191 @@ from library import * # class Conv2dOperation: - # - def __init__(self, conv_kind, conv_type, arch, tile_description, src, flt, bias, dst, 
element_epilogue, \ - epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity4, \ - special_optimization = SpecialOptimizeDesc.NoneSpecialOpt, implicit_gemm_mode = ImplicitGemmMode.GemmNT, \ - without_shared_load = False, required_cuda_ver_major = 9, required_cuda_ver_minor = 2): - - self.operation_kind = OperationKind.Conv2d - self.conv_kind = conv_kind - self.arch = arch - self.tile_description = tile_description - self.conv_type = conv_type - self.src = src - self.flt = flt - self.bias = bias - self.dst = dst - self.element_epilogue = element_epilogue - self.epilogue_functor = epilogue_functor - self.swizzling_functor = swizzling_functor - self.special_optimization = special_optimization - self.implicit_gemm_mode = implicit_gemm_mode - self.without_shared_load = without_shared_load - self.required_cuda_ver_major = required_cuda_ver_major - self.required_cuda_ver_minor = required_cuda_ver_minor - - # - def accumulator_type(self): - accum = self.tile_description.math_instruction.element_accumulator - - return accum - - # - def core_name(self): - ''' The basic operation kind is prefixed with a letter indicating the accumulation type. 
''' - - intermediate_type = '' - - if self.tile_description.math_instruction.opcode_class == OpcodeClass.TensorOp: - inst_shape = "%d%d%d" % tuple(self.tile_description.math_instruction.instruction_shape) - if self.tile_description.math_instruction.element_a != self.flt.element and \ - self.tile_description.math_instruction.element_a != self.accumulator_type(): - intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a] - else: - inst_shape = '' - - special_opt = '' - if self.special_optimization == SpecialOptimizeDesc.ConvFilterUnity: - special_opt = '_1x1' - elif self.special_optimization == SpecialOptimizeDesc.DeconvDoubleUpsampling: - special_opt = '_s2' - - reorder_k = '' - if self.without_shared_load: - reorder_k = '_roc' - - return "%s%s%s%s%s%s_%s" % (ShortDataTypeNames[self.accumulator_type()], \ - inst_shape, intermediate_type, ConvKindNames[self.conv_kind], special_opt, \ - reorder_k, ShortEpilogueNames[self.epilogue_functor]) - - # - def extended_name(self): - if self.dst.element != self.tile_description.math_instruction.element_accumulator: - if self.src.element != self.flt.element: - extended_name = "${element_dst}_${core_name}_${element_src}_${element_flt}" - elif self.src.element == self.flt.element: - extended_name = "${element_dst}_${core_name}_${element_src}" - else: - if self.src.element != self.flt.element: - extended_name = "${core_name}_${element_src}_${element_flt}" - elif self.src.element == self.flt.element: - extended_name = "${core_name}_${element_src}" - - extended_name = SubstituteTemplate(extended_name, { - 'element_src': DataTypeNames[self.src.element], - 'element_flt': DataTypeNames[self.flt.element], - 'element_dst': DataTypeNames[self.dst.element], - 'core_name': self.core_name() - }) - - return extended_name - - # - def layout_name(self): - if self.src.layout == self.dst.layout: - layout_name = "${src_layout}_${flt_layout}" - else: - layout_name = "${src_layout}_${flt_layout}_${dst_layout}" - - 
layout_name = SubstituteTemplate(layout_name, { - 'src_layout': ShortLayoutTypeNames[self.src.layout], - 'flt_layout': ShortLayoutTypeNames[self.flt.layout], - 'dst_layout': ShortLayoutTypeNames[self.dst.layout], - }) - - return layout_name + # + def __init__( + self, + conv_kind, + conv_type, + arch, + tile_description, + src, + flt, + bias, + dst, + element_epilogue, + epilogue_functor=EpilogueFunctor.LinearCombination, + swizzling_functor=SwizzlingFunctor.Identity4, + special_optimization=SpecialOptimizeDesc.NoneSpecialOpt, + implicit_gemm_mode=ImplicitGemmMode.GemmNT, + without_shared_load=False, + required_cuda_ver_major=9, + required_cuda_ver_minor=2, + ): + + self.operation_kind = OperationKind.Conv2d + self.conv_kind = conv_kind + self.arch = arch + self.tile_description = tile_description + self.conv_type = conv_type + self.src = src + self.flt = flt + self.bias = bias + self.dst = dst + self.element_epilogue = element_epilogue + self.epilogue_functor = epilogue_functor + self.swizzling_functor = swizzling_functor + self.special_optimization = special_optimization + self.implicit_gemm_mode = implicit_gemm_mode + self.without_shared_load = without_shared_load + self.required_cuda_ver_major = required_cuda_ver_major + self.required_cuda_ver_minor = required_cuda_ver_minor + + # + def accumulator_type(self): + accum = self.tile_description.math_instruction.element_accumulator + + return accum + + # + def core_name(self): + """ The basic operation kind is prefixed with a letter indicating the accumulation type. 
""" + + intermediate_type = "" + + if self.tile_description.math_instruction.opcode_class == OpcodeClass.TensorOp: + inst_shape = "%d%d%d" % tuple( + self.tile_description.math_instruction.instruction_shape + ) + if ( + self.tile_description.math_instruction.element_a != self.flt.element + and self.tile_description.math_instruction.element_a + != self.accumulator_type() + ): + intermediate_type = DataTypeNames[ + self.tile_description.math_instruction.element_a + ] + else: + inst_shape = "" + + special_opt = "" + if self.special_optimization == SpecialOptimizeDesc.ConvFilterUnity: + special_opt = "_1x1" + elif self.special_optimization == SpecialOptimizeDesc.DeconvDoubleUpsampling: + special_opt = "_s2" + + reorder_k = "" + if self.without_shared_load: + reorder_k = "_roc" + + conv_type_name = "" + if self.conv_type == ConvType.DepthwiseConvolution: + conv_type_name = "dw" + + return "%s%s%s%s%s%s%s_%s" % ( + ShortDataTypeNames[self.accumulator_type()], + inst_shape, + intermediate_type, + conv_type_name, + ConvKindNames[self.conv_kind], + special_opt, + reorder_k, + ShortEpilogueNames[self.epilogue_functor], + ) + + # + def extended_name(self): + if ( + self.dst.element + != self.tile_description.math_instruction.element_accumulator + ): + if self.src.element != self.flt.element: + extended_name = ( + "${element_dst}_${core_name}_${element_src}_${element_flt}" + ) + elif self.src.element == self.flt.element: + extended_name = "${element_dst}_${core_name}_${element_src}" + else: + if self.src.element != self.flt.element: + extended_name = "${core_name}_${element_src}_${element_flt}" + elif self.src.element == self.flt.element: + extended_name = "${core_name}_${element_src}" + + extended_name = SubstituteTemplate( + extended_name, + { + "element_src": DataTypeNames[self.src.element], + "element_flt": DataTypeNames[self.flt.element], + "element_dst": DataTypeNames[self.dst.element], + "core_name": self.core_name(), + }, + ) + + return extended_name + + # + def 
layout_name(self): + if self.src.layout == self.dst.layout: + layout_name = "${src_layout}_${flt_layout}" + else: + layout_name = "${src_layout}_${flt_layout}_${dst_layout}" + + layout_name = SubstituteTemplate( + layout_name, + { + "src_layout": ShortLayoutTypeNames[self.src.layout], + "flt_layout": ShortLayoutTypeNames[self.flt.layout], + "dst_layout": ShortLayoutTypeNames[self.dst.layout], + }, + ) + + return layout_name + + # + def configuration_name(self): + """ The full procedural name indicates architecture, extended name, tile size, and layout. """ + + opcode_class_name = OpcodeClassNames[ + self.tile_description.math_instruction.opcode_class + ] + + warp_shape = [ + int( + self.tile_description.threadblock_shape[idx] + / self.tile_description.warp_count[idx] + ) + for idx in range(3) + ] + + threadblock = "%dx%dx%d_%dx%dx%d_%d" % ( + self.tile_description.threadblock_shape[0], + self.tile_description.threadblock_shape[1], + self.tile_description.threadblock_shape[2], + warp_shape[0], + warp_shape[1], + warp_shape[2], + self.tile_description.stages, + ) + + alignment = "align%dx%d" % (self.src.alignment, self.flt.alignment) + + configuration_name = "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_${alignment}" + + return SubstituteTemplate( + configuration_name, + { + "opcode_class": opcode_class_name, + "extended_name": self.extended_name(), + "threadblock": threadblock, + "layout": self.layout_name(), + "alignment": alignment, + }, + ) + + # + def procedural_name(self): + """ The full procedural name indicates architecture, extended name, tile size, and layout. """ + return self.configuration_name() -# - def configuration_name(self): - ''' The full procedural name indicates architecture, extended name, tile size, and layout. 
''' - - opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class] - - warp_shape = [int(self.tile_description.threadblock_shape[idx] / self.tile_description.warp_count[idx]) for idx in range(3)] - - - threadblock = "%dx%dx%d_%dx%dx%d_%d" % ( - self.tile_description.threadblock_shape[0], - self.tile_description.threadblock_shape[1], - self.tile_description.threadblock_shape[2], - warp_shape[0], - warp_shape[1], - warp_shape[2], - self.tile_description.stages, - ) - - configuration_name = "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}" - - return SubstituteTemplate( - configuration_name, - { - 'opcode_class': opcode_class_name, - 'extended_name': self.extended_name(), - 'threadblock': threadblock, - 'layout': self.layout_name(), - } - ) - - # - def procedural_name(self): - ''' The full procedural name indicates architecture, extended name, tile size, and layout. ''' - return self.configuration_name() ################################################################################################### # @@ -153,9 +208,10 @@ class Conv2dOperation: # ################################################################################################### + class EmitConv2dInstance: - def __init__(self): - self.template = """ + def __init__(self): + self.template = """ // kernel instance "${operation_name}" generated by cutlass generator using Convolution = typename cutlass::conv::device::Convolution< @@ -191,54 +247,75 @@ using Convolution = ${without_shared_load}>; """ + def emit(self, operation): + + warp_shape = [ + int( + operation.tile_description.threadblock_shape[idx] + / operation.tile_description.warp_count[idx] + ) + for idx in range(3) + ] + + epilogue_vector_length = int( + min(operation.dst.alignment * DataTypeSize[operation.dst.element], 128) + / DataTypeSize[operation.dst.element] + ) + + values = { + "operation_name": operation.procedural_name(), + "conv_type": ConvTypeTag[operation.conv_type], + "element_src": 
DataTypeTag[operation.src.element], + "layout_src": LayoutTag[operation.src.layout], + "element_flt": DataTypeTag[operation.flt.element], + "layout_flt": LayoutTag[operation.flt.layout], + "element_dst": DataTypeTag[operation.dst.element], + "layout_dst": LayoutTag[operation.dst.layout], + "element_bias": DataTypeTag[operation.bias.element], + "layout_bias": LayoutTag[operation.bias.layout], + "element_accumulator": DataTypeTag[operation.accumulator_type()], + "opcode_class": OpcodeClassTag[ + operation.tile_description.math_instruction.opcode_class + ], + "arch": "cutlass::arch::Sm%d" % operation.arch, + "threadblock_shape_m": str(operation.tile_description.threadblock_shape[0]), + "threadblock_shape_n": str(operation.tile_description.threadblock_shape[1]), + "threadblock_shape_k": str(operation.tile_description.threadblock_shape[2]), + "warp_shape_m": str(warp_shape[0]), + "warp_shape_n": str(warp_shape[1]), + "warp_shape_k": str(warp_shape[2]), + "instruction_shape_m": str( + operation.tile_description.math_instruction.instruction_shape[0] + ), + "instruction_shape_n": str( + operation.tile_description.math_instruction.instruction_shape[1] + ), + "instruction_shape_k": str( + operation.tile_description.math_instruction.instruction_shape[2] + ), + "epilogue_vector_length": str(epilogue_vector_length), + "epilogue_functor": EpilogueFunctorTag[operation.epilogue_functor], + "element_epilogue": str(DataTypeTag[operation.element_epilogue]), + "swizzling_functor": SwizzlingFunctorTag[operation.swizzling_functor], + "stages": str(operation.tile_description.stages), + "alignment_src": str(operation.src.alignment), + "alignment_filter": str(operation.flt.alignment), + "special_optimization": SpecialOptimizeDescTag[ + operation.special_optimization + ], + "math_operator": MathOperationTag[ + operation.tile_description.math_instruction.math_operation + ], + "implicit_gemm_mode": ImplicitGemmModeTag[operation.implicit_gemm_mode], + "without_shared_load": 
str(operation.without_shared_load).lower(), + } + + return SubstituteTemplate(self.template, values) - def emit(self, operation): - - warp_shape = [int(operation.tile_description.threadblock_shape[idx] / operation.tile_description.warp_count[idx]) for idx in range(3)] - - epilogue_vector_length = int(min(operation.dst.alignment * DataTypeSize[operation.dst.element], 128) / DataTypeSize[operation.dst.element]) - - values = { - 'operation_name': operation.procedural_name(), - 'conv_type': ConvTypeTag[operation.conv_type], - 'element_src': DataTypeTag[operation.src.element], - 'layout_src': LayoutTag[operation.src.layout], - 'element_flt': DataTypeTag[operation.flt.element], - 'layout_flt': LayoutTag[operation.flt.layout], - 'element_dst': DataTypeTag[operation.dst.element], - 'layout_dst': LayoutTag[operation.dst.layout], - 'element_bias': DataTypeTag[operation.bias.element], - 'layout_bias': LayoutTag[operation.bias.layout], - 'element_accumulator': DataTypeTag[operation.accumulator_type()], - 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], - 'arch': "cutlass::arch::Sm%d" % operation.arch, - 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]), - 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]), - 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]), - 'warp_shape_m': str(warp_shape[0]), - 'warp_shape_n': str(warp_shape[1]), - 'warp_shape_k': str(warp_shape[2]), - 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]), - 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]), - 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]), - 'epilogue_vector_length': str(epilogue_vector_length), - 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor], - 'element_epilogue': str(DataTypeTag[operation.element_epilogue]), - 
'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor], - 'stages': str(operation.tile_description.stages), - 'alignment_src': str(operation.src.alignment), - 'alignment_filter': str(operation.flt.alignment), - 'special_optimization': SpecialOptimizeDescTag[operation.special_optimization], - 'math_operator': MathOperationTag[operation.tile_description.math_instruction.math_operation], - 'implicit_gemm_mode': ImplicitGemmModeTag[operation.implicit_gemm_mode], - 'without_shared_load': str(operation.without_shared_load).lower() - } - - return SubstituteTemplate(self.template, values) class EmitDeconvInstance: - def __init__(self): - self.template = """ + def __init__(self): + self.template = """ // kernel instance "${operation_name}" generated by cutlass generator using Deconvolution = typename cutlass::conv::device::Deconvolution< @@ -273,49 +350,69 @@ using Deconvolution = ${implicit_gemm_mode}>; """ - - def emit(self, operation): - - warp_shape = [int(operation.tile_description.threadblock_shape[idx] / operation.tile_description.warp_count[idx]) for idx in range(3)] - - epilogue_vector_length = int(min(operation.dst.alignment * DataTypeSize[operation.dst.element], 128) / DataTypeSize[operation.dst.element]) - - values = { - 'operation_name': operation.procedural_name(), - 'conv_type': ConvTypeTag[operation.conv_type], - 'element_src': DataTypeTag[operation.src.element], - 'layout_src': LayoutTag[operation.src.layout], - 'element_flt': DataTypeTag[operation.flt.element], - 'layout_flt': LayoutTag[operation.flt.layout], - 'element_dst': DataTypeTag[operation.dst.element], - 'layout_dst': LayoutTag[operation.dst.layout], - 'element_bias': DataTypeTag[operation.bias.element], - 'layout_bias': LayoutTag[operation.bias.layout], - 'element_accumulator': DataTypeTag[operation.accumulator_type()], - 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], - 'arch': "cutlass::arch::Sm%d" % operation.arch, - 
'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]), - 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]), - 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]), - 'warp_shape_m': str(warp_shape[0]), - 'warp_shape_n': str(warp_shape[1]), - 'warp_shape_k': str(warp_shape[2]), - 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]), - 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]), - 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]), - 'epilogue_vector_length': str(epilogue_vector_length), - 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor], - 'element_epilogue': str(DataTypeTag[operation.element_epilogue]), - 'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor], - 'stages': str(operation.tile_description.stages), - 'alignment_src': str(operation.src.alignment), - 'alignment_filter': str(operation.flt.alignment), - 'special_optimization': SpecialOptimizeDescTag[operation.special_optimization], - 'math_operator': MathOperationTag[operation.tile_description.math_instruction.math_operation], - 'implicit_gemm_mode': ImplicitGemmModeTag[operation.implicit_gemm_mode] - } - - return SubstituteTemplate(self.template, values) + def emit(self, operation): + + warp_shape = [ + int( + operation.tile_description.threadblock_shape[idx] + / operation.tile_description.warp_count[idx] + ) + for idx in range(3) + ] + + epilogue_vector_length = int( + min(operation.dst.alignment * DataTypeSize[operation.dst.element], 128) + / DataTypeSize[operation.dst.element] + ) + + values = { + "operation_name": operation.procedural_name(), + "conv_type": ConvTypeTag[operation.conv_type], + "element_src": DataTypeTag[operation.src.element], + "layout_src": LayoutTag[operation.src.layout], + "element_flt": DataTypeTag[operation.flt.element], + "layout_flt": 
LayoutTag[operation.flt.layout], + "element_dst": DataTypeTag[operation.dst.element], + "layout_dst": LayoutTag[operation.dst.layout], + "element_bias": DataTypeTag[operation.bias.element], + "layout_bias": LayoutTag[operation.bias.layout], + "element_accumulator": DataTypeTag[operation.accumulator_type()], + "opcode_class": OpcodeClassTag[ + operation.tile_description.math_instruction.opcode_class + ], + "arch": "cutlass::arch::Sm%d" % operation.arch, + "threadblock_shape_m": str(operation.tile_description.threadblock_shape[0]), + "threadblock_shape_n": str(operation.tile_description.threadblock_shape[1]), + "threadblock_shape_k": str(operation.tile_description.threadblock_shape[2]), + "warp_shape_m": str(warp_shape[0]), + "warp_shape_n": str(warp_shape[1]), + "warp_shape_k": str(warp_shape[2]), + "instruction_shape_m": str( + operation.tile_description.math_instruction.instruction_shape[0] + ), + "instruction_shape_n": str( + operation.tile_description.math_instruction.instruction_shape[1] + ), + "instruction_shape_k": str( + operation.tile_description.math_instruction.instruction_shape[2] + ), + "epilogue_vector_length": str(epilogue_vector_length), + "epilogue_functor": EpilogueFunctorTag[operation.epilogue_functor], + "element_epilogue": str(DataTypeTag[operation.element_epilogue]), + "swizzling_functor": SwizzlingFunctorTag[operation.swizzling_functor], + "stages": str(operation.tile_description.stages), + "alignment_src": str(operation.src.alignment), + "alignment_filter": str(operation.flt.alignment), + "special_optimization": SpecialOptimizeDescTag[ + operation.special_optimization + ], + "math_operator": MathOperationTag[ + operation.tile_description.math_instruction.math_operation + ], + "implicit_gemm_mode": ImplicitGemmModeTag[operation.implicit_gemm_mode], + } + + return SubstituteTemplate(self.template, values) ################################################################################################### @@ -325,104 +422,209 @@ using 
Deconvolution = ################################################################################################### # -def GenerateConv2d(conv_kind, tile_descriptions, src_layout, flt_layout, dst_layout, dst_type, min_cc, src_align = 32, flt_align = 32, dst_align = 32, \ - use_special_optimization = SpecialOptimizeDesc.NoneSpecialOpt, implicit_gemm_mode = ImplicitGemmMode.GemmNT, without_shared_load = False, \ - required_cuda_ver_major = 9, required_cuda_ver_minor = 2): - operations = [] - - element_epilogue = DataType.f32 - if conv_kind == ConvKind.Fprop: - if implicit_gemm_mode == ImplicitGemmMode.GemmTN: - swizzling_functor = SwizzlingFunctor.ConvFpropTrans - else: - swizzling_functor = SwizzlingFunctor.ConvFpropNCxHWx - else: - if implicit_gemm_mode == ImplicitGemmMode.GemmTN: - swizzling_functor = SwizzlingFunctor.ConvDgradTrans - else: - swizzling_functor = SwizzlingFunctor.ConvDgradNCxHWx - - # skip rule - def filter_tile_with_layout(tile: TileDescription, layout: LayoutType) -> bool: - return layout == LayoutType.TensorNC32HW32 and \ - tile.threadblock_shape[0] % 32 != 0 - - # rule for bias_type and epilogues - def get_bias_type_and_epilogues(tile: TileDescription, \ - out_dtype: DataType) -> Tuple[DataType, List[EpilogueFunctor]]: - if tile.math_instruction.element_accumulator == DataType.s32 and \ - out_dtype != DataType.f32: - bias_type = DataType.s32 - if tile.math_instruction.element_b == DataType.u4: - epilogues = [EpilogueFunctor.BiasAddLinearCombinationClamp, EpilogueFunctor.BiasAddLinearCombinationReluClamp] - else: - epilogues = [EpilogueFunctor.BiasAddLinearCombinationClamp, EpilogueFunctor.BiasAddLinearCombinationReluClamp, \ - EpilogueFunctor.BiasAddLinearCombinationHSwishClamp] - elif tile.math_instruction.element_accumulator == DataType.f32 or \ - out_dtype == DataType.f32: - bias_type = DataType.f32 - epilogues = [EpilogueFunctor.BiasAddLinearCombination, EpilogueFunctor.BiasAddLinearCombinationRelu, \ - 
EpilogueFunctor.BiasAddLinearCombinationHSwish] - return bias_type, epilogues - - # rule for filter alignment - def get_flt_align(tile: TileDescription) -> int: - nonlocal flt_align - if tile.math_instruction.opcode_class == OpcodeClass.Simt \ - and tile.math_instruction.element_accumulator == DataType.s32: - thread_num = tile.warp_count[0] * tile.warp_count[1] * tile.warp_count[2] * 32 - flt_block = tile.threadblock_shape[0] * tile.threadblock_shape[2] \ - * DataTypeSize[tile.math_instruction.element_a] - load_per_thread = flt_block//thread_num - if load_per_thread >= 128: - flt_align = 128 - elif load_per_thread >= 64: - flt_align = 64 - else: - assert load_per_thread >= 32 - flt_align = 32 - return flt_align - - def get_dst_align(tile: TileDescription, out_layout: LayoutType) -> int: - nonlocal dst_align - if tile.math_instruction.opcode_class == OpcodeClass.TensorOp \ - and dst_layout == LayoutType.TensorNC4HW4: - dst_align = 32 - return dst_align - - def filter_epilogue_with_conv_kind(epilogue: EpilogueFunctor, conv_kind: ConvKind) -> bool: - return conv_kind == ConvKind.Dgrad \ - and epilogue != EpilogueFunctor.BiasAddLinearCombinationClamp - - # loop over all tile descriptions - for tile in tile_descriptions: - if filter_tile_with_layout(tile, dst_layout): - continue - - bias_type, epilogues = get_bias_type_and_epilogues(tile, dst_type) - - flt_align = get_flt_align(tile) - - dst_align = get_dst_align(tile, dst_layout) - - for epilogue in epilogues: - if filter_epilogue_with_conv_kind(epilogue, conv_kind): - continue - - if dst_type == DataType.f32: - bias_type = DataType.f32 - # - src = TensorDescription(tile.math_instruction.element_b, src_layout, int(src_align / DataTypeSize[tile.math_instruction.element_b])) - flt = TensorDescription(tile.math_instruction.element_a, flt_layout, int(flt_align / DataTypeSize[tile.math_instruction.element_a])) - bias = TensorDescription(bias_type, dst_layout, max(1, int(32 / DataTypeSize[bias_type]))) - dst = 
TensorDescription(dst_type, dst_layout, int(dst_align / DataTypeSize[dst_type])) - - new_operation = Conv2dOperation(conv_kind, ConvType.Convolution, min_cc, tile, src, flt, bias, dst, element_epilogue, epilogue, swizzling_functor, SpecialOptimizeDesc.NoneSpecialOpt, implicit_gemm_mode, without_shared_load, required_cuda_ver_major, required_cuda_ver_minor) - operations.append(new_operation) - if use_special_optimization != SpecialOptimizeDesc.NoneSpecialOpt: - new_operation = Conv2dOperation(conv_kind, ConvType.Convolution, min_cc, tile, src, flt, bias, dst, element_epilogue, epilogue, swizzling_functor, use_special_optimization , implicit_gemm_mode, without_shared_load, required_cuda_ver_major, required_cuda_ver_minor) - operations.append(new_operation) - return operations +def GenerateConv2d( + conv_type, + conv_kind, + tile_descriptions, + src_layout, + flt_layout, + dst_layout, + dst_type, + min_cc, + src_align=32, + flt_align=32, + dst_align=32, + use_special_optimization=SpecialOptimizeDesc.NoneSpecialOpt, + implicit_gemm_mode=ImplicitGemmMode.GemmNT, + without_shared_load=False, + required_cuda_ver_major=9, + required_cuda_ver_minor=2, +): + operations = [] + + element_epilogue = DataType.f32 + if conv_type == ConvType.DepthwiseConvolution: + if conv_kind == ConvKind.Fprop: + swizzling_functor = SwizzlingFunctor.DepthwiseConvolutionFprop + elif conv_kind == ConvKind.Dgrad: + swizzling_functor = SwizzlingFunctor.DepthwiseConvolutionDgrad + else: + assert conv_kind == ConvKind.Wgrad + swizzling_functor = SwizzlingFunctor.DepthwiseConvolutionWgrad + elif conv_type == ConvType.Convolution: + if conv_kind == ConvKind.Fprop: + if implicit_gemm_mode == ImplicitGemmMode.GemmTN: + swizzling_functor = SwizzlingFunctor.ConvFpropTrans + else: + swizzling_functor = SwizzlingFunctor.ConvFpropNCxHWx + else: + if implicit_gemm_mode == ImplicitGemmMode.GemmTN: + swizzling_functor = SwizzlingFunctor.ConvDgradTrans + else: + swizzling_functor = SwizzlingFunctor.ConvDgradNCxHWx 
+ + # skip rule + def filter_tile_with_layout(tile: TileDescription, layout: LayoutType) -> bool: + return ( + layout == LayoutType.TensorNC32HW32 and tile.threadblock_shape[0] % 32 != 0 + ) + + # rule for bias_type and epilogues + def get_bias_type_and_epilogues( + tile: TileDescription, out_dtype: DataType + ) -> Tuple[DataType, List[EpilogueFunctor]]: + if ( + tile.math_instruction.element_accumulator == DataType.s32 + and out_dtype != DataType.f32 + ): + bias_type = DataType.s32 + if tile.math_instruction.element_b == DataType.u4: + epilogues = [ + EpilogueFunctor.BiasAddLinearCombinationClamp, + EpilogueFunctor.BiasAddLinearCombinationReluClamp, + ] + else: + epilogues = [ + EpilogueFunctor.BiasAddLinearCombinationClamp, + EpilogueFunctor.BiasAddLinearCombinationReluClamp, + EpilogueFunctor.BiasAddLinearCombinationHSwishClamp, + ] + elif ( + tile.math_instruction.element_accumulator == DataType.f32 + or tile.math_instruction.element_accumulator == DataType.f16 + ) or ( + tile.math_instruction.element_accumulator == DataType.s32 + and out_dtype == DataType.f32 + ): + bias_type = out_dtype + epilogues = [ + EpilogueFunctor.BiasAddLinearCombination, + EpilogueFunctor.BiasAddLinearCombinationRelu, + ] + if conv_type == ConvType.Convolution: + epilogues.append(EpilogueFunctor.BiasAddLinearCombinationHSwish) + else: + assert False, "invalid path" + return bias_type, epilogues + + # rule for filter alignment + def get_flt_align(tile: TileDescription) -> int: + nonlocal flt_align + if ( + tile.math_instruction.opcode_class == OpcodeClass.Simt + and tile.math_instruction.element_accumulator == DataType.s32 + ): + thread_num = ( + tile.warp_count[0] * tile.warp_count[1] * tile.warp_count[2] * 32 + ) + flt_block = ( + tile.threadblock_shape[0] + * tile.threadblock_shape[2] + * DataTypeSize[tile.math_instruction.element_a] + ) + load_per_thread = flt_block // thread_num + if load_per_thread >= 128: + flt_align = 128 + elif load_per_thread >= 64: + flt_align = 64 + else: + 
assert load_per_thread >= 32 + flt_align = 32 + return flt_align + + def get_dst_align(tile: TileDescription, out_layout: LayoutType) -> int: + nonlocal dst_align + if ( + tile.math_instruction.opcode_class == OpcodeClass.TensorOp + and dst_layout == LayoutType.TensorNC4HW4 + ): + dst_align = 32 + return dst_align + + def filter_epilogue_with_conv_kind( + epilogue: EpilogueFunctor, conv_kind: ConvKind + ) -> bool: + return ( + conv_kind == ConvKind.Dgrad + and epilogue != EpilogueFunctor.BiasAddLinearCombinationClamp + ) + + # loop over all tile descriptions + for tile in tile_descriptions: + if filter_tile_with_layout(tile, dst_layout): + continue + + bias_type, epilogues = get_bias_type_and_epilogues(tile, dst_type) + + flt_align = get_flt_align(tile) + + dst_align = get_dst_align(tile, dst_layout) + + for epilogue in epilogues: + if filter_epilogue_with_conv_kind(epilogue, conv_kind): + continue + + if dst_type == DataType.f32: + bias_type = DataType.f32 + # + src = TensorDescription( + tile.math_instruction.element_b, + src_layout, + int(src_align / DataTypeSize[tile.math_instruction.element_b]), + ) + flt = TensorDescription( + tile.math_instruction.element_a, + flt_layout, + int(flt_align / DataTypeSize[tile.math_instruction.element_a]), + ) + bias = TensorDescription( + bias_type, dst_layout, max(1, int(32 / DataTypeSize[bias_type])) + ) + dst = TensorDescription( + dst_type, dst_layout, int(dst_align / DataTypeSize[dst_type]) + ) + + new_operation = Conv2dOperation( + conv_kind, + conv_type, + min_cc, + tile, + src, + flt, + bias, + dst, + element_epilogue, + epilogue, + swizzling_functor, + SpecialOptimizeDesc.NoneSpecialOpt, + implicit_gemm_mode, + without_shared_load, + required_cuda_ver_major, + required_cuda_ver_minor, + ) + operations.append(new_operation) + if use_special_optimization != SpecialOptimizeDesc.NoneSpecialOpt: + new_operation = Conv2dOperation( + conv_kind, + conv_type, + min_cc, + tile, + src, + flt, + bias, + dst, + element_epilogue, + 
epilogue, + swizzling_functor, + use_special_optimization, + implicit_gemm_mode, + without_shared_load, + required_cuda_ver_major, + required_cuda_ver_minor, + ) + operations.append(new_operation) + return operations + ################################################################################################### # @@ -430,14 +632,17 @@ def GenerateConv2d(conv_kind, tile_descriptions, src_layout, flt_layout, dst_lay # ################################################################################################### + class EmitConv2dConfigurationLibrary: - def __init__(self, operation_path, configuration_name): - self.configuration_name = configuration_name - self.configuration_path = os.path.join(operation_path, "%s.cu" % configuration_name) + def __init__(self, operation_path, configuration_name): + self.configuration_name = configuration_name + self.configuration_path = os.path.join( + operation_path, "%s.cu" % configuration_name + ) - self.instance_emitter = EmitConv2dInstance() + self.instance_emitter = EmitConv2dInstance() - self.instance_template = """ + self.instance_template = """ ${operation_instance} // Derived class @@ -447,7 +652,7 @@ struct ${operation_name} : /////////////////////////////////////////////////////////////////////////////////////////////////// """ - self.header_template = """ + self.header_template = """ /* Generated by conv2d_operation.py - Do not edit. 
*/ @@ -464,7 +669,7 @@ struct ${operation_name} : /////////////////////////////////////////////////////////////////////////////////////////////////// """ - self.configuration_header = """ + self.configuration_header = """ namespace cutlass { namespace library { @@ -474,7 +679,7 @@ void initialize_${configuration_name}(Manifest &manifest) { """ - self.configuration_instance = """ + self.configuration_instance = """ using Operation_${operation_name} = cutlass::conv::device::ImplicitGemmConvolution< ${operation_name}>; @@ -484,10 +689,10 @@ void initialize_${configuration_name}(Manifest &manifest) { """ - self.configuration_epilogue = """ + self.configuration_epilogue = """ } """ - self.epilogue_template = """ + self.epilogue_template = """ /////////////////////////////////////////////////////////////////////////////////////////////////// @@ -498,40 +703,56 @@ void initialize_${configuration_name}(Manifest &manifest) { """ - # - def __enter__(self): - self.configuration_file = open(self.configuration_path, "w") - self.configuration_file.write(SubstituteTemplate(self.header_template, { - 'configuration_name': self.configuration_name - })) - self.operations = [] - return self - - # - def emit(self, operation): - self.operations.append(operation) - self.configuration_file.write(SubstituteTemplate(self.instance_template, { - 'configuration_name': self.configuration_name, - 'operation_name': operation.procedural_name(), - 'operation_instance': self.instance_emitter.emit(operation) - })) - - # - def __exit__(self, exception_type, exception_value, traceback): - - self.configuration_file.write(SubstituteTemplate(self.configuration_header, { - 'configuration_name': self.configuration_name - })) - - for operation in self.operations: - self.configuration_file.write(SubstituteTemplate(self.configuration_instance, { - 'configuration_name': self.configuration_name, - 'operation_name': operation.procedural_name() - })) - - self.configuration_file.write(self.configuration_epilogue) - 
self.configuration_file.write(self.epilogue_template) - self.configuration_file.close() + # + def __enter__(self): + self.configuration_file = open(self.configuration_path, "w") + self.configuration_file.write( + SubstituteTemplate( + self.header_template, {"configuration_name": self.configuration_name} + ) + ) + self.operations = [] + return self + + # + def emit(self, operation): + self.operations.append(operation) + self.configuration_file.write( + SubstituteTemplate( + self.instance_template, + { + "configuration_name": self.configuration_name, + "operation_name": operation.procedural_name(), + "operation_instance": self.instance_emitter.emit(operation), + }, + ) + ) + + # + def __exit__(self, exception_type, exception_value, traceback): + + self.configuration_file.write( + SubstituteTemplate( + self.configuration_header, + {"configuration_name": self.configuration_name}, + ) + ) + + for operation in self.operations: + self.configuration_file.write( + SubstituteTemplate( + self.configuration_instance, + { + "configuration_name": self.configuration_name, + "operation_name": operation.procedural_name(), + }, + ) + ) + + self.configuration_file.write(self.configuration_epilogue) + self.configuration_file.write(self.epilogue_template) + self.configuration_file.close() + ################################################################################################### ################################################################################################### @@ -540,21 +761,22 @@ void initialize_${configuration_name}(Manifest &manifest) { # ################################################################################################### -class EmitConvSingleKernelWrapper(): - def __init__(self, kernel_path, operation, short_path=False): - self.kernel_path = kernel_path - self.operation = operation - self.short_path = short_path - - if self.operation.conv_kind == ConvKind.Fprop: - self.instance_emitter = EmitConv2dInstance() - self.convolution_name = 
"Convolution" - else: - assert self.operation.conv_kind == ConvKind.Dgrad - self.instance_emitter = EmitDeconvInstance() - self.convolution_name = "Deconvolution" - - self.header_template = """ + +class EmitConvSingleKernelWrapper: + def __init__(self, kernel_path, operation, short_path=False): + self.kernel_path = kernel_path + self.operation = operation + self.short_path = short_path + + if self.operation.conv_kind == ConvKind.Fprop: + self.instance_emitter = EmitConv2dInstance() + self.convolution_name = "Convolution" + else: + assert self.operation.conv_kind == ConvKind.Dgrad + self.instance_emitter = EmitDeconvInstance() + self.convolution_name = "Deconvolution" + + self.header_template = """ #if __CUDACC_VER_MAJOR__ > ${required_cuda_ver_major} || (__CUDACC_VER_MAJOR__ == ${required_cuda_ver_major} && __CUDACC_VER_MINOR__ >= ${required_cuda_ver_minor}) // ignore warning of cutlass #pragma GCC diagnostic push @@ -568,11 +790,11 @@ class EmitConvSingleKernelWrapper(): #include "src/cuda/cutlass/manifest.h" #include "src/cuda/cutlass/convolution_operation.h" """ - self.instance_template = """ + self.instance_template = """ ${operation_instance} """ - self.manifest_template = """ + self.manifest_template = """ namespace cutlass { namespace library { @@ -586,44 +808,60 @@ void initialize_${operation_name}(Manifest &manifest) { } // namespace cutlass """ - self.epilogue_template = """ + self.epilogue_template = """ #pragma GCC diagnostic pop #endif """ - # - def __enter__(self): - if self.short_path: - self.kernel_path = os.path.join(self.kernel_path, "%s.cu" % GlobalCnt.cnt) - GlobalCnt.cnt += 1 - else: - self.kernel_path = os.path.join(self.kernel_path, "%s.cu" % self.operation.procedural_name()) - self.kernel_file = open(self.kernel_path, "w") - self.kernel_file.write(SubstituteTemplate(self.header_template, { - 'required_cuda_ver_major': str(self.operation.required_cuda_ver_major), - 'required_cuda_ver_minor': str(self.operation.required_cuda_ver_minor), - })) 
- return self - - # - def emit(self): - self.kernel_file.write(SubstituteTemplate(self.instance_template, { - 'operation_instance': self.instance_emitter.emit(self.operation), - })) - - # emit manifest helper - manifest = SubstituteTemplate(self.manifest_template, { - 'operation_name': self.operation.procedural_name(), - 'convolution_name': self.convolution_name - }) - self.kernel_file.write(manifest) - - # - def __exit__(self, exception_type, exception_value, traceback): - self.kernel_file.write(self.epilogue_template) - self.kernel_file.close() + # + def __enter__(self): + if self.short_path: + self.kernel_path = os.path.join(self.kernel_path, "%s.cu" % GlobalCnt.cnt) + GlobalCnt.cnt += 1 + else: + self.kernel_path = os.path.join( + self.kernel_path, "%s.cu" % self.operation.procedural_name() + ) + self.kernel_file = open(self.kernel_path, "w") + self.kernel_file.write( + SubstituteTemplate( + self.header_template, + { + "required_cuda_ver_major": str( + self.operation.required_cuda_ver_major + ), + "required_cuda_ver_minor": str( + self.operation.required_cuda_ver_minor + ), + }, + ) + ) + return self + + # + def emit(self): + self.kernel_file.write( + SubstituteTemplate( + self.instance_template, + {"operation_instance": self.instance_emitter.emit(self.operation)}, + ) + ) + + # emit manifest helper + manifest = SubstituteTemplate( + self.manifest_template, + { + "operation_name": self.operation.procedural_name(), + "convolution_name": self.convolution_name, + }, + ) + self.kernel_file.write(manifest) + + # + def __exit__(self, exception_type, exception_value, traceback): + self.kernel_file.write(self.epilogue_template) + self.kernel_file.close() ################################################################################################### ################################################################################################### - diff --git a/dnn/scripts/cutlass_generator/gemm_operation.py b/dnn/scripts/cutlass_generator/gemm_operation.py index 
47a5a82f..a1583fdb 100644 --- a/dnn/scripts/cutlass_generator/gemm_operation.py +++ b/dnn/scripts/cutlass_generator/gemm_operation.py @@ -21,140 +21,188 @@ from library import * # class GemmOperation: - # - def __init__(self, gemm_kind, arch, tile_description, A, B, C, element_epilogue, \ - epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity8, \ - required_cuda_ver_major = 9, required_cuda_ver_minor = 2): - - self.operation_kind = OperationKind.Gemm - self.arch = arch - self.tile_description = tile_description - self.gemm_kind = gemm_kind - self.A = A - self.B = B - self.C = C - self.element_epilogue = element_epilogue - self.epilogue_functor = epilogue_functor - self.swizzling_functor = swizzling_functor - self.required_cuda_ver_major = required_cuda_ver_major - self.required_cuda_ver_minor = required_cuda_ver_minor - - - # - def is_complex(self): - complex_operators = [ - MathOperation.multiply_add_complex, - MathOperation.multiply_add_complex_gaussian - ] - return self.tile_description.math_instruction.math_operation in complex_operators - - # - def is_split_k_parallel(self): - return self.gemm_kind == GemmKind.SplitKParallel - - # - def is_planar_complex(self): - return self.gemm_kind in (GemmKind.PlanarComplex, GemmKind.PlanarComplexArray) - - # - def accumulator_type(self): - accum = self.tile_description.math_instruction.element_accumulator - - if self.is_complex(): - return get_complex_from_real(accum) - - return accum - - # - def short_math_name(self): - if self.tile_description.math_instruction.math_operation == MathOperation.multiply_add_complex_gaussian: - return "g%s" % ShortDataTypeNames[self.accumulator_type()] - return ShortDataTypeNames[self.accumulator_type()] - - - # - def core_name(self): - ''' The basic operation kind is prefixed with a letter indicating the accumulation type. 
''' - - inst_shape = '' - inst_operation = '' - intermediate_type = '' - - math_operations_map = { - MathOperation.xor_popc: 'xor', - } - - if self.tile_description.math_instruction.opcode_class == OpcodeClass.TensorOp or \ - self.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp: - - math_op = self.tile_description.math_instruction.math_operation - math_op_string = math_operations_map[math_op] if math_op in math_operations_map.keys() else '' - - inst_shape = "%d%d%d" % tuple(self.tile_description.math_instruction.instruction_shape) - inst_shape += math_op_string - - if self.tile_description.math_instruction.element_a != self.A.element and \ - self.tile_description.math_instruction.element_a != self.tile_description.math_instruction.element_accumulator: - intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a] - - return "%s%s%s%s" % (self.short_math_name(), inst_shape, intermediate_type, GemmKindNames[self.gemm_kind]) - - # - def extended_name(self): - ''' Append data types if they differ from compute type. 
''' - if self.is_complex(): - extended_name = "${core_name}" - else: - if self.C.element != self.tile_description.math_instruction.element_accumulator and \ - self.A.element != self.tile_description.math_instruction.element_accumulator: - extended_name = "${element_c}_${core_name}_${element_a}" - elif self.C.element == self.tile_description.math_instruction.element_accumulator and \ - self.A.element != self.tile_description.math_instruction.element_accumulator: - extended_name = "${core_name}_${element_a}" - else: - extended_name = "${core_name}" - - extended_name = SubstituteTemplate(extended_name, { - 'element_a': DataTypeNames[self.A.element], - 'element_c': DataTypeNames[self.C.element], - 'core_name': self.core_name() - }) - - return extended_name - - # - def layout_name(self): - if self.is_complex() or self.is_planar_complex(): - return "%s%s" % ( - ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)], - ShortComplexLayoutNames[(self.B.layout, self.B.complex_transform)] - ) - return "%s%s" % (ShortLayoutTypeNames[self.A.layout], ShortLayoutTypeNames[self.B.layout]) - - # - def procedural_name(self): - ''' The full procedural name indicates architecture, extended name, tile size, and layout. 
''' - threadblock = self.tile_description.procedural_name() - - opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class] - - alignment = max([self.A.alignment, self.B.alignment, self.C.alignment]) - - return SubstituteTemplate( - "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_align${alignment}", - { - 'opcode_class': opcode_class_name, - 'extended_name': self.extended_name(), - 'threadblock': threadblock, - 'layout': self.layout_name(), - 'alignment': "%d" % self.A.alignment, - } - ) + # + def __init__( + self, + gemm_kind, + arch, + tile_description, + A, + B, + C, + element_epilogue, + epilogue_functor=EpilogueFunctor.LinearCombination, + swizzling_functor=SwizzlingFunctor.Identity8, + required_cuda_ver_major=9, + required_cuda_ver_minor=2, + ): + + self.operation_kind = OperationKind.Gemm + self.arch = arch + self.tile_description = tile_description + self.gemm_kind = gemm_kind + self.A = A + self.B = B + self.C = C + self.element_epilogue = element_epilogue + self.epilogue_functor = epilogue_functor + self.swizzling_functor = swizzling_functor + self.required_cuda_ver_major = required_cuda_ver_major + self.required_cuda_ver_minor = required_cuda_ver_minor + + # + def is_complex(self): + complex_operators = [ + MathOperation.multiply_add_complex, + MathOperation.multiply_add_complex_gaussian, + ] + return ( + self.tile_description.math_instruction.math_operation in complex_operators + ) + + # + def is_split_k_parallel(self): + return self.gemm_kind == GemmKind.SplitKParallel + + # + def is_planar_complex(self): + return self.gemm_kind in (GemmKind.PlanarComplex, GemmKind.PlanarComplexArray) + + # + def accumulator_type(self): + accum = self.tile_description.math_instruction.element_accumulator + + if self.is_complex(): + return get_complex_from_real(accum) + + return accum + + # + def short_math_name(self): + if ( + self.tile_description.math_instruction.math_operation + == 
MathOperation.multiply_add_complex_gaussian + ): + return "g%s" % ShortDataTypeNames[self.accumulator_type()] + return ShortDataTypeNames[self.accumulator_type()] + + # + def core_name(self): + """ The basic operation kind is prefixed with a letter indicating the accumulation type. """ + + inst_shape = "" + inst_operation = "" + intermediate_type = "" + + math_operations_map = {MathOperation.xor_popc: "xor"} + + if ( + self.tile_description.math_instruction.opcode_class == OpcodeClass.TensorOp + or self.tile_description.math_instruction.opcode_class + == OpcodeClass.WmmaTensorOp + ): + + math_op = self.tile_description.math_instruction.math_operation + math_op_string = ( + math_operations_map[math_op] + if math_op in math_operations_map.keys() + else "" + ) + + inst_shape = "%d%d%d" % tuple( + self.tile_description.math_instruction.instruction_shape + ) + inst_shape += math_op_string + + if ( + self.tile_description.math_instruction.element_a != self.A.element + and self.tile_description.math_instruction.element_a + != self.tile_description.math_instruction.element_accumulator + ): + intermediate_type = DataTypeNames[ + self.tile_description.math_instruction.element_a + ] + + return "%s%s%s%s" % ( + self.short_math_name(), + inst_shape, + intermediate_type, + GemmKindNames[self.gemm_kind], + ) + + # + def extended_name(self): + """ Append data types if they differ from compute type. 
""" + if self.is_complex(): + extended_name = "${core_name}" + else: + if ( + self.C.element + != self.tile_description.math_instruction.element_accumulator + and self.A.element + != self.tile_description.math_instruction.element_accumulator + ): + extended_name = "${element_c}_${core_name}_${element_a}" + elif ( + self.C.element + == self.tile_description.math_instruction.element_accumulator + and self.A.element + != self.tile_description.math_instruction.element_accumulator + ): + extended_name = "${core_name}_${element_a}" + else: + extended_name = "${core_name}" + + extended_name = SubstituteTemplate( + extended_name, + { + "element_a": DataTypeNames[self.A.element], + "element_c": DataTypeNames[self.C.element], + "core_name": self.core_name(), + }, + ) + + return extended_name + + # + def layout_name(self): + if self.is_complex() or self.is_planar_complex(): + return "%s%s" % ( + ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)], + ShortComplexLayoutNames[(self.B.layout, self.B.complex_transform)], + ) + return "%s%s" % ( + ShortLayoutTypeNames[self.A.layout], + ShortLayoutTypeNames[self.B.layout], + ) + + # + def procedural_name(self): + """ The full procedural name indicates architecture, extended name, tile size, and layout. """ + threadblock = self.tile_description.procedural_name() + + opcode_class_name = OpcodeClassNames[ + self.tile_description.math_instruction.opcode_class + ] + + alignment = max([self.A.alignment, self.B.alignment, self.C.alignment]) + + return SubstituteTemplate( + "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_align${alignment}", + { + "opcode_class": opcode_class_name, + "extended_name": self.extended_name(), + "threadblock": threadblock, + "layout": self.layout_name(), + "alignment": "%d" % self.A.alignment, + }, + ) + + # + def configuration_name(self): + """ The full procedural name indicates architecture, extended name, tile size, and layout. 
""" + return self.procedural_name() - # - def configuration_name(self): - ''' The full procedural name indicates architecture, extended name, tile size, and layout. ''' - return self.procedural_name() ################################################################################################### # @@ -164,127 +212,219 @@ class GemmOperation: # class GemvBatchedStridedOperation: - # - def __init__(self, gemm_kind, arch, math_inst, threadblock_shape, thread_shape, A, B, C, \ - required_cuda_ver_major = 9, required_cuda_ver_minor = 2): - - self.operation_kind = OperationKind.Gemm - self.arch = arch - self.gemm_kind = gemm_kind - self.math_instruction = math_inst - self.threadblock_shape = threadblock_shape - self.thread_shape = thread_shape - self.A = A - self.B = B - self.C = C - self.required_cuda_ver_major = required_cuda_ver_major - self.required_cuda_ver_minor = required_cuda_ver_minor - - # - def accumulator_type(self): - accum = self.math_instruction.element_accumulator - - return accum - - # - def short_math_name(self): - return ShortDataTypeNames[self.accumulator_type()] - - - # - def core_name(self): - ''' The basic operation kind is prefixed with a letter indicating the accumulation type. ''' - - return "%s%s" % (self.short_math_name(), \ - GemmKindNames[self.gemm_kind]) - - # - def extended_name(self): - ''' Append data types if they differ from compute type. 
''' - if self.C.element != self.math_instruction.element_accumulator and \ - self.A.element != self.math_instruction.element_accumulator: - extended_name = "${element_c}_${core_name}_${element_a}" - elif self.C.element == self.math_instruction.element_accumulator and \ - self.A.element != self.math_instruction.element_accumulator: - extended_name = "${core_name}_${element_a}" - else: - extended_name = "${core_name}" - - extended_name = SubstituteTemplate(extended_name, { - 'element_a': DataTypeNames[self.A.element], - 'element_c': DataTypeNames[self.C.element], - 'core_name': self.core_name() - }) - - return extended_name - - # - def layout_name(self): - return "%s%s" % (ShortLayoutTypeNames[self.A.layout], ShortLayoutTypeNames[self.B.layout]) - - # - def procedural_name(self): - ''' The full procedural name indicates architecture, extended name, tile size, and layout. ''' - threadblock = "%dx%d_%d" % (self.threadblock_shape[0], self.threadblock_shape[1], self.threadblock_shape[2]) - - opcode_class_name = OpcodeClassNames[self.math_instruction.opcode_class] - - alignment_a = self.A.alignment - alignment_b = self.B.alignment - - return SubstituteTemplate( - "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_align${alignment_a}x${alignment_b}", - { - 'opcode_class': opcode_class_name, - 'extended_name': self.extended_name(), - 'threadblock': threadblock, - 'layout': self.layout_name(), - 'alignment_a': "%d" % alignment_a, - 'alignment_b': "%d" % alignment_b, - } - ) + # + def __init__( + self, + gemm_kind, + arch, + math_inst, + threadblock_shape, + thread_shape, + A, + B, + C, + required_cuda_ver_major=9, + required_cuda_ver_minor=2, + ): + + self.operation_kind = OperationKind.Gemm + self.arch = arch + self.gemm_kind = gemm_kind + self.math_instruction = math_inst + self.threadblock_shape = threadblock_shape + self.thread_shape = thread_shape + self.A = A + self.B = B + self.C = C + self.required_cuda_ver_major = required_cuda_ver_major + 
self.required_cuda_ver_minor = required_cuda_ver_minor + + # + def accumulator_type(self): + accum = self.math_instruction.element_accumulator + + return accum + + # + def short_math_name(self): + return ShortDataTypeNames[self.accumulator_type()] + + # + def core_name(self): + """ The basic operation kind is prefixed with a letter indicating the accumulation type. """ + + return "%s%s" % (self.short_math_name(), GemmKindNames[self.gemm_kind]) + + # + def extended_name(self): + """ Append data types if they differ from compute type. """ + if ( + self.C.element != self.math_instruction.element_accumulator + and self.A.element != self.math_instruction.element_accumulator + ): + extended_name = "${element_c}_${core_name}_${element_a}" + elif ( + self.C.element == self.math_instruction.element_accumulator + and self.A.element != self.math_instruction.element_accumulator + ): + extended_name = "${core_name}_${element_a}" + else: + extended_name = "${core_name}" + + extended_name = SubstituteTemplate( + extended_name, + { + "element_a": DataTypeNames[self.A.element], + "element_c": DataTypeNames[self.C.element], + "core_name": self.core_name(), + }, + ) + + return extended_name + + # + def layout_name(self): + return "%s%s" % ( + ShortLayoutTypeNames[self.A.layout], + ShortLayoutTypeNames[self.B.layout], + ) + + # + def procedural_name(self): + """ The full procedural name indicates architecture, extended name, tile size, and layout. 
""" + threadblock = "%dx%d_%d" % ( + self.threadblock_shape[0], + self.threadblock_shape[1], + self.threadblock_shape[2], + ) + + opcode_class_name = OpcodeClassNames[self.math_instruction.opcode_class] + + alignment_a = self.A.alignment + alignment_b = self.B.alignment + + return SubstituteTemplate( + "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_align${alignment_a}x${alignment_b}", + { + "opcode_class": opcode_class_name, + "extended_name": self.extended_name(), + "threadblock": threadblock, + "layout": self.layout_name(), + "alignment_a": "%d" % alignment_a, + "alignment_b": "%d" % alignment_b, + }, + ) + + # + def configuration_name(self): + """ The full procedural name indicates architecture, extended name, tile size, and layout. """ + return self.procedural_name() - # - def configuration_name(self): - ''' The full procedural name indicates architecture, extended name, tile size, and layout. ''' - return self.procedural_name() # -def GeneratesGemm(tile, data_type, layout_a, layout_b, layout_c, min_cc, align_a = 32, align_b = 32, align_c = 32, required_cuda_ver_major = 9, required_cuda_ver_minor = 2): - operations = [] - swizzling_functor = SwizzlingFunctor.Identity1 - - element_a, element_b, element_c, element_epilogue = data_type - - if tile.math_instruction.element_accumulator == DataType.s32: - epilogues = [EpilogueFunctor.LinearCombinationClamp] - else: - assert tile.math_instruction.element_accumulator == DataType.f32 or \ - tile.math_instruction.element_accumulator == DataType.f16 - epilogues = [EpilogueFunctor.LinearCombination] - - for epilogue in epilogues: - A = TensorDescription(element_a, layout_a, int(align_a//DataTypeSize[element_a])) - B = TensorDescription(element_b, layout_b, int(align_b//DataTypeSize[element_b])) - C = TensorDescription(element_c, layout_c, int(align_c//DataTypeSize[element_c])) - operations.append(GemmOperation(GemmKind.Gemm, min_cc, tile, A, B, C, \ - element_epilogue, epilogue, swizzling_functor, \ - 
required_cuda_ver_major, required_cuda_ver_minor)) - operations.append(GemmOperation(GemmKind.SplitKParallel, min_cc, tile, A, B, C, \ - element_epilogue, epilogue, swizzling_functor, \ - required_cuda_ver_major, required_cuda_ver_minor)) - return operations - -def GeneratesGemv(math_inst, threadblock_shape, thread_shape, data_type, layout_a, layout_b, layout_c, min_cc, \ - align_a = 32, align_b = 32, align_c = 32, \ - required_cuda_ver_major = 9, required_cuda_ver_minor = 2): - element_a, element_b, element_c, element_epilogue = data_type - - A = TensorDescription(element_a, layout_a, int(align_a//DataTypeSize[element_a])) - B = TensorDescription(element_b, layout_b, int(align_b//DataTypeSize[element_b])) - C = TensorDescription(element_c, layout_c, int(align_c//DataTypeSize[element_c])) - return GemvBatchedStridedOperation(GemmKind.GemvBatchedStrided, min_cc, math_inst, threadblock_shape, thread_shape, \ - A, B, C, required_cuda_ver_major, required_cuda_ver_minor) +def GeneratesGemm( + tile, + data_type, + layout_a, + layout_b, + layout_c, + min_cc, + align_a=32, + align_b=32, + align_c=32, + required_cuda_ver_major=9, + required_cuda_ver_minor=2, +): + operations = [] + swizzling_functor = SwizzlingFunctor.Identity1 + + element_a, element_b, element_c, element_epilogue = data_type + + if tile.math_instruction.element_accumulator == DataType.s32: + epilogues = [EpilogueFunctor.LinearCombinationClamp] + else: + assert ( + tile.math_instruction.element_accumulator == DataType.f32 + or tile.math_instruction.element_accumulator == DataType.f16 + ) + epilogues = [EpilogueFunctor.LinearCombination] + + for epilogue in epilogues: + A = TensorDescription( + element_a, layout_a, int(align_a // DataTypeSize[element_a]) + ) + B = TensorDescription( + element_b, layout_b, int(align_b // DataTypeSize[element_b]) + ) + C = TensorDescription( + element_c, layout_c, int(align_c // DataTypeSize[element_c]) + ) + operations.append( + GemmOperation( + GemmKind.Gemm, + min_cc, + 
tile, + A, + B, + C, + element_epilogue, + epilogue, + swizzling_functor, + required_cuda_ver_major, + required_cuda_ver_minor, + ) + ) + operations.append( + GemmOperation( + GemmKind.SplitKParallel, + min_cc, + tile, + A, + B, + C, + element_epilogue, + epilogue, + swizzling_functor, + required_cuda_ver_major, + required_cuda_ver_minor, + ) + ) + return operations + + +def GeneratesGemv( + math_inst, + threadblock_shape, + thread_shape, + data_type, + layout_a, + layout_b, + layout_c, + min_cc, + align_a=32, + align_b=32, + align_c=32, + required_cuda_ver_major=9, + required_cuda_ver_minor=2, +): + element_a, element_b, element_c, element_epilogue = data_type + + A = TensorDescription(element_a, layout_a, int(align_a // DataTypeSize[element_a])) + B = TensorDescription(element_b, layout_b, int(align_b // DataTypeSize[element_b])) + C = TensorDescription(element_c, layout_c, int(align_c // DataTypeSize[element_c])) + return GemvBatchedStridedOperation( + GemmKind.GemvBatchedStrided, + min_cc, + math_inst, + threadblock_shape, + thread_shape, + A, + B, + C, + required_cuda_ver_major, + required_cuda_ver_minor, + ) + ################################################################################################### # @@ -294,10 +434,10 @@ def GeneratesGemv(math_inst, threadblock_shape, thread_shape, data_type, layout_ # class EmitGemmInstance: - ''' Responsible for emitting a CUTLASS template definition''' + """ Responsible for emitting a CUTLASS template definition""" - def __init__(self): - self.gemm_template = """ + def __init__(self): + self.gemm_template = """ // Gemm operator ${operation_name} using Operation_${operation_name} = cutlass::gemm::device::Gemm< ${element_a}, ${layout_a}, @@ -324,7 +464,7 @@ class EmitGemmInstance: ${residual} >; """ - self.gemm_complex_template = """ + self.gemm_complex_template = """ // Gemm operator ${operation_name} using Operation_${operation_name} = cutlass::gemm::device::GemmComplex< ${element_a}, ${layout_a}, @@ -351,57 
+491,77 @@ class EmitGemmInstance: >; """ - def emit(self, operation): - - warp_shape = [operation.tile_description.threadblock_shape[idx] // operation.tile_description.warp_count[idx] for idx in range(3)] - - epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element]) - - residual = '' - - values = { - 'operation_name': operation.procedural_name(), - 'element_a': DataTypeTag[operation.A.element], - 'layout_a': LayoutTag[operation.A.layout], - 'element_b': DataTypeTag[operation.B.element], - 'layout_b': LayoutTag[operation.B.layout], - 'element_c': DataTypeTag[operation.C.element], - 'layout_c': LayoutTag[operation.C.layout], - 'element_accumulator': DataTypeTag[operation.accumulator_type()], - 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], - 'arch': "cutlass::arch::Sm%d" % operation.arch, - 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]), - 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]), - 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]), - 'warp_shape_m': str(warp_shape[0]), - 'warp_shape_n': str(warp_shape[1]), - 'warp_shape_k': str(warp_shape[2]), - 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]), - 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]), - 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]), - 'epilogue_vector_length': str(epilogue_vector_length), - 'element_epilogue': str(DataTypeTag[operation.element_epilogue]), - 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor], - 'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor], - 'stages': str(operation.tile_description.stages), - 'align_a': str(operation.A.alignment), - 'align_b': str(operation.B.alignment), - 'transform_a': 
ComplexTransformTag[operation.A.complex_transform], - 'transform_b': ComplexTransformTag[operation.B.complex_transform], - 'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation], - 'residual': residual - } - - template = self.gemm_complex_template if operation.is_complex() else self.gemm_template - - return SubstituteTemplate(template, values) + def emit(self, operation): + + warp_shape = [ + operation.tile_description.threadblock_shape[idx] + // operation.tile_description.warp_count[idx] + for idx in range(3) + ] + + epilogue_vector_length = int( + min(operation.C.alignment * DataTypeSize[operation.C.element], 128) + / DataTypeSize[operation.C.element] + ) + + residual = "" + + values = { + "operation_name": operation.procedural_name(), + "element_a": DataTypeTag[operation.A.element], + "layout_a": LayoutTag[operation.A.layout], + "element_b": DataTypeTag[operation.B.element], + "layout_b": LayoutTag[operation.B.layout], + "element_c": DataTypeTag[operation.C.element], + "layout_c": LayoutTag[operation.C.layout], + "element_accumulator": DataTypeTag[operation.accumulator_type()], + "opcode_class": OpcodeClassTag[ + operation.tile_description.math_instruction.opcode_class + ], + "arch": "cutlass::arch::Sm%d" % operation.arch, + "threadblock_shape_m": str(operation.tile_description.threadblock_shape[0]), + "threadblock_shape_n": str(operation.tile_description.threadblock_shape[1]), + "threadblock_shape_k": str(operation.tile_description.threadblock_shape[2]), + "warp_shape_m": str(warp_shape[0]), + "warp_shape_n": str(warp_shape[1]), + "warp_shape_k": str(warp_shape[2]), + "instruction_shape_m": str( + operation.tile_description.math_instruction.instruction_shape[0] + ), + "instruction_shape_n": str( + operation.tile_description.math_instruction.instruction_shape[1] + ), + "instruction_shape_k": str( + operation.tile_description.math_instruction.instruction_shape[2] + ), + "epilogue_vector_length": str(epilogue_vector_length), 
+ "element_epilogue": str(DataTypeTag[operation.element_epilogue]), + "epilogue_functor": EpilogueFunctorTag[operation.epilogue_functor], + "swizzling_functor": SwizzlingFunctorTag[operation.swizzling_functor], + "stages": str(operation.tile_description.stages), + "align_a": str(operation.A.alignment), + "align_b": str(operation.B.alignment), + "transform_a": ComplexTransformTag[operation.A.complex_transform], + "transform_b": ComplexTransformTag[operation.B.complex_transform], + "math_operation": MathOperationTag[ + operation.tile_description.math_instruction.math_operation + ], + "residual": residual, + } + + template = ( + self.gemm_complex_template if operation.is_complex() else self.gemm_template + ) + + return SubstituteTemplate(template, values) + # class EmitGemvBatchedStridedInstance: - ''' Responsible for emitting a CUTLASS template definition''' + """ Responsible for emitting a CUTLASS template definition""" - def __init__(self): - self.template = """ + def __init__(self): + self.template = """ // Gemm operator ${operation_name} using Operation_${operation_name} = cutlass::gemm::kernel::DefaultGemv< cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, @@ -412,34 +572,35 @@ class EmitGemvBatchedStridedInstance: >; """ - def emit(self, operation): + def emit(self, operation): - values = { - 'operation_name': operation.procedural_name(), - 'element_a': DataTypeTag[operation.A.element], - 'layout_a': LayoutTag[operation.A.layout], - 'element_b': DataTypeTag[operation.B.element], - 'layout_b': LayoutTag[operation.B.layout], - 'element_c': DataTypeTag[operation.C.element], - 'layout_c': LayoutTag[operation.C.layout], - 'threadblock_shape_m': str(operation.threadblock_shape[0]), - 'threadblock_shape_n': str(operation.threadblock_shape[1]), - 'threadblock_shape_k': str(operation.threadblock_shape[2]), - 'thread_shape_m': str(operation.thread_shape[0]), - 'thread_shape_n': str(operation.thread_shape[1]), - 
'thread_shape_k': str(operation.thread_shape[2]), - } + values = { + "operation_name": operation.procedural_name(), + "element_a": DataTypeTag[operation.A.element], + "layout_a": LayoutTag[operation.A.layout], + "element_b": DataTypeTag[operation.B.element], + "layout_b": LayoutTag[operation.B.layout], + "element_c": DataTypeTag[operation.C.element], + "layout_c": LayoutTag[operation.C.layout], + "threadblock_shape_m": str(operation.threadblock_shape[0]), + "threadblock_shape_n": str(operation.threadblock_shape[1]), + "threadblock_shape_k": str(operation.threadblock_shape[2]), + "thread_shape_m": str(operation.thread_shape[0]), + "thread_shape_n": str(operation.thread_shape[1]), + "thread_shape_k": str(operation.thread_shape[2]), + } - return SubstituteTemplate(self.template, values) + return SubstituteTemplate(self.template, values) ################################################################################################### + class EmitSparseGemmInstance: - ''' Responsible for emitting a CUTLASS template definition''' + """ Responsible for emitting a CUTLASS template definition""" - def __init__(self): - self.gemm_template = """ + def __init__(self): + self.gemm_template = """ // Gemm operator ${operation_name} using Operation_${operation_name} = cutlass::gemm::device::SparseGemm< ${element_a}, ${layout_a}, @@ -467,60 +628,78 @@ class EmitSparseGemmInstance: >; """ - def emit(self, operation): - - warp_shape = [operation.tile_description.threadblock_shape[idx] // operation.tile_description.warp_count[idx] for idx in range(3)] - - epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element]) - - residual = '' - - values = { - 'operation_name': operation.procedural_name(), - 'element_a': DataTypeTag[operation.A.element], - 'layout_a': LayoutTag[operation.A.layout], - 'element_b': DataTypeTag[operation.B.element], - 'layout_b': LayoutTag[operation.B.layout], - 'element_c': 
DataTypeTag[operation.C.element], - 'layout_c': LayoutTag[operation.C.layout], - 'element_accumulator': DataTypeTag[operation.accumulator_type()], - 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], - 'arch': "cutlass::arch::Sm%d" % operation.arch, - 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]), - 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]), - 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]), - 'warp_shape_m': str(warp_shape[0]), - 'warp_shape_n': str(warp_shape[1]), - 'warp_shape_k': str(warp_shape[2]), - 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]), - 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]), - 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]), - 'epilogue_vector_length': str(epilogue_vector_length), - 'element_epilogue': str(DataTypeTag[operation.element_epilogue]), - 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor], - 'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor], - 'stages': str(operation.tile_description.stages), - 'align_a': str(operation.A.alignment), - 'align_b': str(operation.B.alignment), - 'transform_a': ComplexTransformTag[operation.A.complex_transform], - 'transform_b': ComplexTransformTag[operation.B.complex_transform], - 'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation], - 'residual': residual - } - - template = self.gemm_template - - return SubstituteTemplate(template, values) + def emit(self, operation): + + warp_shape = [ + operation.tile_description.threadblock_shape[idx] + // operation.tile_description.warp_count[idx] + for idx in range(3) + ] + + epilogue_vector_length = int( + min(operation.C.alignment * DataTypeSize[operation.C.element], 128) + / DataTypeSize[operation.C.element] + ) + + 
residual = "" + + values = { + "operation_name": operation.procedural_name(), + "element_a": DataTypeTag[operation.A.element], + "layout_a": LayoutTag[operation.A.layout], + "element_b": DataTypeTag[operation.B.element], + "layout_b": LayoutTag[operation.B.layout], + "element_c": DataTypeTag[operation.C.element], + "layout_c": LayoutTag[operation.C.layout], + "element_accumulator": DataTypeTag[operation.accumulator_type()], + "opcode_class": OpcodeClassTag[ + operation.tile_description.math_instruction.opcode_class + ], + "arch": "cutlass::arch::Sm%d" % operation.arch, + "threadblock_shape_m": str(operation.tile_description.threadblock_shape[0]), + "threadblock_shape_n": str(operation.tile_description.threadblock_shape[1]), + "threadblock_shape_k": str(operation.tile_description.threadblock_shape[2]), + "warp_shape_m": str(warp_shape[0]), + "warp_shape_n": str(warp_shape[1]), + "warp_shape_k": str(warp_shape[2]), + "instruction_shape_m": str( + operation.tile_description.math_instruction.instruction_shape[0] + ), + "instruction_shape_n": str( + operation.tile_description.math_instruction.instruction_shape[1] + ), + "instruction_shape_k": str( + operation.tile_description.math_instruction.instruction_shape[2] + ), + "epilogue_vector_length": str(epilogue_vector_length), + "element_epilogue": str(DataTypeTag[operation.element_epilogue]), + "epilogue_functor": EpilogueFunctorTag[operation.epilogue_functor], + "swizzling_functor": SwizzlingFunctorTag[operation.swizzling_functor], + "stages": str(operation.tile_description.stages), + "align_a": str(operation.A.alignment), + "align_b": str(operation.B.alignment), + "transform_a": ComplexTransformTag[operation.A.complex_transform], + "transform_b": ComplexTransformTag[operation.B.complex_transform], + "math_operation": MathOperationTag[ + operation.tile_description.math_instruction.math_operation + ], + "residual": residual, + } + + template = self.gemm_template + + return SubstituteTemplate(template, values) + 
################################################################################################### # class EmitGemmUniversalInstance: - ''' Responsible for emitting a CUTLASS template definition''' + """ Responsible for emitting a CUTLASS template definition""" - def __init__(self): - self.gemm_template = """ + def __init__(self): + self.gemm_template = """ // Gemm operator ${operation_name} using ${operation_name}_base = typename cutlass::gemm::kernel::DefaultGemmUniversal< @@ -548,7 +727,7 @@ using ${operation_name}_base = struct ${operation_name} : public ${operation_name}_base { }; """ - self.gemm_template_interleaved = """ + self.gemm_template_interleaved = """ // Gemm operator ${operation_name} using ${operation_name}_base = typename cutlass::gemm::kernel::DefaultGemmUniversal< @@ -577,78 +756,97 @@ struct ${operation_name} : public ${operation_name}_base { }; """ - def emit(self, operation): - - threadblock_shape = operation.tile_description.threadblock_shape - warp_count = operation.tile_description.warp_count - - warp_shape = [threadblock_shape[idx] // warp_count[idx] for idx in range(3)] + def emit(self, operation): + + threadblock_shape = operation.tile_description.threadblock_shape + warp_count = operation.tile_description.warp_count + + warp_shape = [threadblock_shape[idx] // warp_count[idx] for idx in range(3)] + + epilogue_vector_length = int( + min(operation.C.alignment * DataTypeSize[operation.C.element], 128) + / DataTypeSize[operation.C.element] + ) + + transpose_layouts = { + LayoutType.ColumnMajor: LayoutType.RowMajor, + LayoutType.RowMajor: LayoutType.ColumnMajor, + } + + if ( + operation.A.layout in transpose_layouts.keys() + and operation.B.layout in transpose_layouts.keys() + and operation.C.layout in transpose_layouts.keys() + ): + + instance_layout_A = transpose_layouts[operation.A.layout] + instance_layout_B = transpose_layouts[operation.B.layout] + instance_layout_C = transpose_layouts[operation.C.layout] + + gemm_template = 
self.gemm_template + else: + instance_layout_A, instance_layout_B, instance_layout_C = ( + operation.A.layout, + operation.B.layout, + operation.C.layout, + ) + + gemm_template = self.gemm_template_interleaved + # + + values = { + "operation_name": operation.procedural_name(), + "element_a": DataTypeTag[operation.A.element], + "layout_a": LayoutTag[instance_layout_A], + "element_b": DataTypeTag[operation.B.element], + "layout_b": LayoutTag[instance_layout_B], + "element_c": DataTypeTag[operation.C.element], + "layout_c": LayoutTag[instance_layout_C], + "element_accumulator": DataTypeTag[operation.accumulator_type()], + "opcode_class": OpcodeClassTag[ + operation.tile_description.math_instruction.opcode_class + ], + "arch": "cutlass::arch::Sm%d" % operation.arch, + "threadblock_shape_m": str(operation.tile_description.threadblock_shape[0]), + "threadblock_shape_n": str(operation.tile_description.threadblock_shape[1]), + "threadblock_shape_k": str(operation.tile_description.threadblock_shape[2]), + "warp_shape_m": str(warp_shape[0]), + "warp_shape_n": str(warp_shape[1]), + "warp_shape_k": str(warp_shape[2]), + "instruction_shape_m": str( + operation.tile_description.math_instruction.instruction_shape[0] + ), + "instruction_shape_n": str( + operation.tile_description.math_instruction.instruction_shape[1] + ), + "instruction_shape_k": str( + operation.tile_description.math_instruction.instruction_shape[2] + ), + "epilogue_vector_length": str(epilogue_vector_length), + "element_epilogue": str(DataTypeTag[operation.element_epilogue]), + "epilogue_functor": EpilogueFunctorTag[operation.epilogue_functor], + "swizzling_functor": SwizzlingFunctorTag[operation.swizzling_functor], + "stages": str(operation.tile_description.stages), + "align_a": str(operation.A.alignment), + "align_b": str(operation.B.alignment), + "transform_a": ComplexTransformTag[operation.A.complex_transform], + "transform_b": ComplexTransformTag[operation.B.complex_transform], + "math_operation": 
MathOperationTag[ + operation.tile_description.math_instruction.math_operation + ], + } + + return SubstituteTemplate(gemm_template, values) - epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element]) - - transpose_layouts = { - LayoutType.ColumnMajor: LayoutType.RowMajor, - LayoutType.RowMajor: LayoutType.ColumnMajor - } - - if operation.A.layout in transpose_layouts.keys() and \ - operation.B.layout in transpose_layouts.keys() and \ - operation.C.layout in transpose_layouts.keys(): - - instance_layout_A = transpose_layouts[operation.A.layout] - instance_layout_B = transpose_layouts[operation.B.layout] - instance_layout_C = transpose_layouts[operation.C.layout] - - gemm_template = self.gemm_template - else: - instance_layout_A, instance_layout_B, instance_layout_C = \ - (operation.A.layout, operation.B.layout, operation.C.layout) - - gemm_template = self.gemm_template_interleaved - # - - values = { - 'operation_name': operation.procedural_name(), - 'element_a': DataTypeTag[operation.A.element], - 'layout_a': LayoutTag[instance_layout_A], - 'element_b': DataTypeTag[operation.B.element], - 'layout_b': LayoutTag[instance_layout_B], - 'element_c': DataTypeTag[operation.C.element], - 'layout_c': LayoutTag[instance_layout_C], - 'element_accumulator': DataTypeTag[operation.accumulator_type()], - 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], - 'arch': "cutlass::arch::Sm%d" % operation.arch, - 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]), - 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]), - 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]), - 'warp_shape_m': str(warp_shape[0]), - 'warp_shape_n': str(warp_shape[1]), - 'warp_shape_k': str(warp_shape[2]), - 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]), - 'instruction_shape_n': 
str(operation.tile_description.math_instruction.instruction_shape[1]), - 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]), - 'epilogue_vector_length': str(epilogue_vector_length), - 'element_epilogue': str(DataTypeTag[operation.element_epilogue]), - 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor], - 'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor], - 'stages': str(operation.tile_description.stages), - 'align_a': str(operation.A.alignment), - 'align_b': str(operation.B.alignment), - 'transform_a': ComplexTransformTag[operation.A.complex_transform], - 'transform_b': ComplexTransformTag[operation.B.complex_transform], - 'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation] - } - - return SubstituteTemplate(gemm_template, values) ################################################################################################### # class EmitGemmPlanarComplexInstance: - ''' Responsible for emitting a CUTLASS template definition''' + """ Responsible for emitting a CUTLASS template definition""" - def __init__(self): - self.template = """ + def __init__(self): + self.template = """ // Gemm operator ${operation_name} using Operation_${operation_name} = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal< ${element_a}, ${layout_a}, ${transform_a}, ${alignment_a}, @@ -675,54 +873,69 @@ class EmitGemmPlanarComplexInstance: public Operation_${operation_name} { }; """ - def emit(self, operation): - - warp_shape = [operation.tile_description.threadblock_shape[idx] // operation.tile_description.warp_count[idx] for idx in range(3)] - - # exchange and transpose A and B types, layouts, and complex transforms since the C layout is row-major - transposed_layout_A = TransposedLayout[operation.A.layout] - transposed_layout_B = TransposedLayout[operation.B.layout] - - values = { - 'operation_name': operation.procedural_name(), - 'element_a': 
DataTypeTag[operation.B.element], - 'layout_a': LayoutTag[transposed_layout_B], - 'transform_a': ComplexTransformTag[operation.B.complex_transform], - 'alignment_a': str(operation.B.alignment), - 'element_b': DataTypeTag[operation.A.element], - 'layout_b': LayoutTag[transposed_layout_A], - 'transform_b': ComplexTransformTag[operation.A.complex_transform], - 'alignment_b': str(operation.A.alignment), - 'element_c': DataTypeTag[operation.C.element], - 'layout_c': LayoutTag[operation.C.layout], - 'element_accumulator': DataTypeTag[operation.tile_description.math_instruction.element_accumulator], - 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], - 'arch': "cutlass::arch::Sm%d" % operation.arch, - 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]), - 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]), - 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]), - 'warp_shape_m': str(warp_shape[0]), - 'warp_shape_n': str(warp_shape[1]), - 'warp_shape_k': str(warp_shape[2]), - 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]), - 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]), - 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]), - 'alignment_c': str(operation.C.alignment), - 'element_epilogue': str(DataTypeTag[operation.element_epilogue]), - 'stages': str(operation.tile_description.stages), - 'math_operator': 'cutlass::arch::OpMultiplyAdd' - } - - return SubstituteTemplate(self.template, values) + def emit(self, operation): + + warp_shape = [ + operation.tile_description.threadblock_shape[idx] + // operation.tile_description.warp_count[idx] + for idx in range(3) + ] + + # exchange and transpose A and B types, layouts, and complex transforms since the C layout is row-major + transposed_layout_A = TransposedLayout[operation.A.layout] + 
transposed_layout_B = TransposedLayout[operation.B.layout] + + values = { + "operation_name": operation.procedural_name(), + "element_a": DataTypeTag[operation.B.element], + "layout_a": LayoutTag[transposed_layout_B], + "transform_a": ComplexTransformTag[operation.B.complex_transform], + "alignment_a": str(operation.B.alignment), + "element_b": DataTypeTag[operation.A.element], + "layout_b": LayoutTag[transposed_layout_A], + "transform_b": ComplexTransformTag[operation.A.complex_transform], + "alignment_b": str(operation.A.alignment), + "element_c": DataTypeTag[operation.C.element], + "layout_c": LayoutTag[operation.C.layout], + "element_accumulator": DataTypeTag[ + operation.tile_description.math_instruction.element_accumulator + ], + "opcode_class": OpcodeClassTag[ + operation.tile_description.math_instruction.opcode_class + ], + "arch": "cutlass::arch::Sm%d" % operation.arch, + "threadblock_shape_m": str(operation.tile_description.threadblock_shape[0]), + "threadblock_shape_n": str(operation.tile_description.threadblock_shape[1]), + "threadblock_shape_k": str(operation.tile_description.threadblock_shape[2]), + "warp_shape_m": str(warp_shape[0]), + "warp_shape_n": str(warp_shape[1]), + "warp_shape_k": str(warp_shape[2]), + "instruction_shape_m": str( + operation.tile_description.math_instruction.instruction_shape[0] + ), + "instruction_shape_n": str( + operation.tile_description.math_instruction.instruction_shape[1] + ), + "instruction_shape_k": str( + operation.tile_description.math_instruction.instruction_shape[2] + ), + "alignment_c": str(operation.C.alignment), + "element_epilogue": str(DataTypeTag[operation.element_epilogue]), + "stages": str(operation.tile_description.stages), + "math_operator": "cutlass::arch::OpMultiplyAdd", + } + + return SubstituteTemplate(self.template, values) + ################################################################################################### # class EmitGemmPlanarComplexArrayInstance: - ''' Responsible for emitting 
a CUTLASS template definition''' + """ Responsible for emitting a CUTLASS template definition""" - def __init__(self): - self.template = """ + def __init__(self): + self.template = """ // Gemm operator ${operation_name} using Operation_${operation_name} = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal< ${element_a}, ${layout_a}, ${transform_a}, ${alignment_a}, @@ -748,52 +961,67 @@ class EmitGemmPlanarComplexArrayInstance: struct ${operation_name} : public Operation_${operation_name} { }; """ - def emit(self, operation): - - warp_shape = [operation.tile_description.threadblock_shape[idx] // operation.tile_description.warp_count[idx] for idx in range(3)] - - # exchange and transpose A and B types, layouts, and complex transforms since the C layout is row-major - transposed_layout_A = TransposedLayout[operation.A.layout] - transposed_layout_B = TransposedLayout[operation.B.layout] - - values = { - 'operation_name': operation.procedural_name(), - 'element_a': DataTypeTag[operation.B.element], - 'layout_a': LayoutTag[transposed_layout_B], - 'transform_a': ComplexTransformTag[operation.B.complex_transform], - 'alignment_a': str(operation.B.alignment), - 'element_b': DataTypeTag[operation.A.element], - 'layout_b': LayoutTag[transposed_layout_A], - 'transform_b': ComplexTransformTag[operation.A.complex_transform], - 'alignment_b': str(operation.A.alignment), - 'element_c': DataTypeTag[operation.C.element], - 'layout_c': LayoutTag[operation.C.layout], - 'element_accumulator': DataTypeTag[operation.tile_description.math_instruction.element_accumulator], - 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], - 'arch': "cutlass::arch::Sm%d" % operation.arch, - 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]), - 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]), - 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]), - 'warp_shape_m': 
str(warp_shape[0]), - 'warp_shape_n': str(warp_shape[1]), - 'warp_shape_k': str(warp_shape[2]), - 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]), - 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]), - 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]), - 'alignment_c': str(operation.C.alignment), - 'element_epilogue': str(DataTypeTag[operation.element_epilogue]), - 'stages': str(operation.tile_description.stages), - 'math_operator': 'cutlass::arch::OpMultiplyAdd' - } - - return SubstituteTemplate(self.template, values) + def emit(self, operation): + + warp_shape = [ + operation.tile_description.threadblock_shape[idx] + // operation.tile_description.warp_count[idx] + for idx in range(3) + ] + + # exchange and transpose A and B types, layouts, and complex transforms since the C layout is row-major + transposed_layout_A = TransposedLayout[operation.A.layout] + transposed_layout_B = TransposedLayout[operation.B.layout] + + values = { + "operation_name": operation.procedural_name(), + "element_a": DataTypeTag[operation.B.element], + "layout_a": LayoutTag[transposed_layout_B], + "transform_a": ComplexTransformTag[operation.B.complex_transform], + "alignment_a": str(operation.B.alignment), + "element_b": DataTypeTag[operation.A.element], + "layout_b": LayoutTag[transposed_layout_A], + "transform_b": ComplexTransformTag[operation.A.complex_transform], + "alignment_b": str(operation.A.alignment), + "element_c": DataTypeTag[operation.C.element], + "layout_c": LayoutTag[operation.C.layout], + "element_accumulator": DataTypeTag[ + operation.tile_description.math_instruction.element_accumulator + ], + "opcode_class": OpcodeClassTag[ + operation.tile_description.math_instruction.opcode_class + ], + "arch": "cutlass::arch::Sm%d" % operation.arch, + "threadblock_shape_m": str(operation.tile_description.threadblock_shape[0]), + "threadblock_shape_n": 
str(operation.tile_description.threadblock_shape[1]), + "threadblock_shape_k": str(operation.tile_description.threadblock_shape[2]), + "warp_shape_m": str(warp_shape[0]), + "warp_shape_n": str(warp_shape[1]), + "warp_shape_k": str(warp_shape[2]), + "instruction_shape_m": str( + operation.tile_description.math_instruction.instruction_shape[0] + ), + "instruction_shape_n": str( + operation.tile_description.math_instruction.instruction_shape[1] + ), + "instruction_shape_k": str( + operation.tile_description.math_instruction.instruction_shape[2] + ), + "alignment_c": str(operation.C.alignment), + "element_epilogue": str(DataTypeTag[operation.element_epilogue]), + "stages": str(operation.tile_description.stages), + "math_operator": "cutlass::arch::OpMultiplyAdd", + } + + return SubstituteTemplate(self.template, values) + # class EmitGemmSplitKParallelInstance: - ''' Responsible for emitting a CUTLASS template definition''' + """ Responsible for emitting a CUTLASS template definition""" - def __init__(self): - self.template = """ + def __init__(self): + self.template = """ // Gemm operator ${operation_name} using Operation_${operation_name} = cutlass::gemm::device::GemmSplitKParallel< ${element_a}, ${layout_a}, @@ -828,42 +1056,60 @@ class EmitGemmSplitKParallelInstance: ${math_operation} >; """ - def emit(self, operation): - - warp_shape = [operation.tile_description.threadblock_shape[idx] // operation.tile_description.warp_count[idx] for idx in range(3)] - - epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element]) - - values = { - 'operation_name': operation.procedural_name(), - 'element_a': DataTypeTag[operation.A.element], - 'layout_a': LayoutTag[operation.A.layout], - 'element_b': DataTypeTag[operation.B.element], - 'layout_b': LayoutTag[operation.B.layout], - 'element_c': DataTypeTag[operation.C.element], - 'layout_c': LayoutTag[operation.C.layout], - 'element_accumulator': 
DataTypeTag[operation.accumulator_type()], - 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], - 'arch': "cutlass::arch::Sm%d" % operation.arch, - 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]), - 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]), - 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]), - 'warp_shape_m': str(warp_shape[0]), - 'warp_shape_n': str(warp_shape[1]), - 'warp_shape_k': str(warp_shape[2]), - 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]), - 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]), - 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]), - 'epilogue_vector_length': str(epilogue_vector_length), - 'element_epilogue': str(DataTypeTag[operation.element_epilogue]), - 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor], - 'stages': str(operation.tile_description.stages), - 'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation], - 'align_a': str(operation.A.alignment), - 'align_b': str(operation.B.alignment), - } - - return SubstituteTemplate(self.template, values) + + def emit(self, operation): + + warp_shape = [ + operation.tile_description.threadblock_shape[idx] + // operation.tile_description.warp_count[idx] + for idx in range(3) + ] + + epilogue_vector_length = int( + min(operation.C.alignment * DataTypeSize[operation.C.element], 128) + / DataTypeSize[operation.C.element] + ) + + values = { + "operation_name": operation.procedural_name(), + "element_a": DataTypeTag[operation.A.element], + "layout_a": LayoutTag[operation.A.layout], + "element_b": DataTypeTag[operation.B.element], + "layout_b": LayoutTag[operation.B.layout], + "element_c": DataTypeTag[operation.C.element], + "layout_c": LayoutTag[operation.C.layout], + 
"element_accumulator": DataTypeTag[operation.accumulator_type()], + "opcode_class": OpcodeClassTag[ + operation.tile_description.math_instruction.opcode_class + ], + "arch": "cutlass::arch::Sm%d" % operation.arch, + "threadblock_shape_m": str(operation.tile_description.threadblock_shape[0]), + "threadblock_shape_n": str(operation.tile_description.threadblock_shape[1]), + "threadblock_shape_k": str(operation.tile_description.threadblock_shape[2]), + "warp_shape_m": str(warp_shape[0]), + "warp_shape_n": str(warp_shape[1]), + "warp_shape_k": str(warp_shape[2]), + "instruction_shape_m": str( + operation.tile_description.math_instruction.instruction_shape[0] + ), + "instruction_shape_n": str( + operation.tile_description.math_instruction.instruction_shape[1] + ), + "instruction_shape_k": str( + operation.tile_description.math_instruction.instruction_shape[2] + ), + "epilogue_vector_length": str(epilogue_vector_length), + "element_epilogue": str(DataTypeTag[operation.element_epilogue]), + "epilogue_functor": EpilogueFunctorTag[operation.epilogue_functor], + "stages": str(operation.tile_description.stages), + "math_operation": MathOperationTag[ + operation.tile_description.math_instruction.math_operation + ], + "align_a": str(operation.A.alignment), + "align_b": str(operation.B.alignment), + } + + return SubstituteTemplate(self.template, values) ################################################################################################### @@ -875,64 +1121,67 @@ class EmitGemmSplitKParallelInstance: # ################################################################################################### + class EmitGemmConfigurationLibrary: - def __init__(self, operation_path, configuration_name): - self.configuration_name = configuration_name - self.configuration_path = os.path.join(operation_path, "%s.cu" % configuration_name).replace('\\', '/') - - self.instance_emitter = { - GemmKind.Gemm: EmitGemmInstance, - GemmKind.Sparse: EmitSparseGemmInstance, - 
GemmKind.Universal: EmitGemmUniversalInstance, - GemmKind.PlanarComplex: EmitGemmPlanarComplexInstance, - GemmKind.PlanarComplexArray: EmitGemmPlanarComplexArrayInstance - } - - self.gemm_kind_wrappers = { - GemmKind.Gemm: 'GemmOperation', - GemmKind.Sparse: 'GemmSparseOperation', - GemmKind.Universal: 'GemmUniversalOperation', - GemmKind.PlanarComplex: 'GemmPlanarComplexOperation', - GemmKind.PlanarComplexArray: 'GemmPlanarComplexArrayOperation' - } - - self.wmma_guard_start = "#if defined(CUTLASS_ARCH_WMMA_SM${sm_number}_ENABLED)" - - self.instance_template = { - GemmKind.Gemm: """ + def __init__(self, operation_path, configuration_name): + self.configuration_name = configuration_name + self.configuration_path = os.path.join( + operation_path, "%s.cu" % configuration_name + ).replace("\\", "/") + + self.instance_emitter = { + GemmKind.Gemm: EmitGemmInstance, + GemmKind.Sparse: EmitSparseGemmInstance, + GemmKind.Universal: EmitGemmUniversalInstance, + GemmKind.PlanarComplex: EmitGemmPlanarComplexInstance, + GemmKind.PlanarComplexArray: EmitGemmPlanarComplexArrayInstance, + } + + self.gemm_kind_wrappers = { + GemmKind.Gemm: "GemmOperation", + GemmKind.Sparse: "GemmSparseOperation", + GemmKind.Universal: "GemmUniversalOperation", + GemmKind.PlanarComplex: "GemmPlanarComplexOperation", + GemmKind.PlanarComplexArray: "GemmPlanarComplexArrayOperation", + } + + self.wmma_guard_start = "#if defined(CUTLASS_ARCH_WMMA_SM${sm_number}_ENABLED)" + + self.instance_template = { + GemmKind.Gemm: """ ${compile_guard_start} manifest.append(new ${gemm_kind}("${operation_name}")); ${compile_guard_end} """, - GemmKind.Sparse: """ + GemmKind.Sparse: """ ${compile_guard_start} manifest.append(new ${gemm_kind}("${operation_name}")); ${compile_guard_end} """, - GemmKind.Universal: """ + GemmKind.Universal: """ ${compile_guard_start} manifest.append(new ${gemm_kind}< cutlass::gemm::device::GemmUniversalAdapter<${operation_name}> >("${operation_name}")); ${compile_guard_end} """, - 
GemmKind.PlanarComplex: """ + GemmKind.PlanarComplex: """ ${compile_guard_start} manifest.append(new ${gemm_kind}< cutlass::gemm::device::GemmUniversalAdapter<${operation_name}> >("${operation_name}")); ${compile_guard_end} """, - GemmKind.PlanarComplexArray: """ + GemmKind.PlanarComplexArray: """ ${compile_guard_start} manifest.append(new ${gemm_kind}< cutlass::gemm::device::GemmUniversalAdapter<${operation_name}> >("${operation_name}")); ${compile_guard_end} -""" - } +""", + } - self.header_template = """ + self.header_template = """ /* Generated by gemm_operation.py - Do not edit. */ @@ -950,7 +1199,7 @@ ${compile_guard_end} """ - self.initialize_function_template = """ + self.initialize_function_template = """ /////////////////////////////////////////////////////////////////////////////////////////////////// @@ -962,7 +1211,7 @@ namespace library { void initialize_${configuration_name}(Manifest &manifest) { """ - self.epilogue_template = """ + self.epilogue_template = """ } @@ -975,66 +1224,82 @@ void initialize_${configuration_name}(Manifest &manifest) { """ - def __enter__(self): - self.configuration_file = open(self.configuration_path, "w") - self.configuration_file.write(self.header_template) - - self.instance_definitions = [] - self.instance_wrappers = [] - - self.operations = [] - return self - - def emit(self, operation): - emitter = self.instance_emitter[operation.gemm_kind]() + def __enter__(self): + self.configuration_file = open(self.configuration_path, "w") + self.configuration_file.write(self.header_template) + + self.instance_definitions = [] + self.instance_wrappers = [] + + self.operations = [] + return self + + def emit(self, operation): + emitter = self.instance_emitter[operation.gemm_kind]() + + self.operations.append(operation) + + self.instance_definitions.append(emitter.emit(operation)) + + self.instance_wrappers.append( + SubstituteTemplate( + self.instance_template[operation.gemm_kind], + { + "configuration_name": 
self.configuration_name, + "operation_name": operation.procedural_name(), + "gemm_kind": self.gemm_kind_wrappers[operation.gemm_kind], + "compile_guard_start": SubstituteTemplate( + self.wmma_guard_start, {"sm_number": str(operation.arch)} + ) + if operation.tile_description.math_instruction.opcode_class + == OpcodeClass.WmmaTensorOp + else "", + "compile_guard_end": "#endif" + if operation.tile_description.math_instruction.opcode_class + == OpcodeClass.WmmaTensorOp + else "", + }, + ) + ) + + def __exit__(self, exception_type, exception_value, traceback): + + # Write instance definitions in top-level namespace + for instance_definition in self.instance_definitions: + self.configuration_file.write(instance_definition) + + # Add wrapper objects within initialize() function + self.configuration_file.write( + SubstituteTemplate( + self.initialize_function_template, + {"configuration_name": self.configuration_name}, + ) + ) + + for instance_wrapper in self.instance_wrappers: + self.configuration_file.write(instance_wrapper) + + self.configuration_file.write(self.epilogue_template) + self.configuration_file.close() - self.operations.append(operation) - - self.instance_definitions.append(emitter.emit(operation)) - - self.instance_wrappers.append(SubstituteTemplate(self.instance_template[operation.gemm_kind], { - 'configuration_name': self.configuration_name, - 'operation_name': operation.procedural_name(), - 'gemm_kind': self.gemm_kind_wrappers[operation.gemm_kind], - 'compile_guard_start': SubstituteTemplate(self.wmma_guard_start, {'sm_number': str(operation.arch)}) \ - if operation.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp else "", - 'compile_guard_end': "#endif" \ - if operation.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp else "" - })) - - def __exit__(self, exception_type, exception_value, traceback): - - # Write instance definitions in top-level namespace - for instance_definition in 
self.instance_definitions: - self.configuration_file.write(instance_definition) - - # Add wrapper objects within initialize() function - self.configuration_file.write(SubstituteTemplate(self.initialize_function_template, { - 'configuration_name': self.configuration_name - })) - - for instance_wrapper in self.instance_wrappers: - self.configuration_file.write(instance_wrapper) - - self.configuration_file.write(self.epilogue_template) - self.configuration_file.close() ################################################################################################### ################################################################################################### + class EmitGemmSingleKernelWrapper: - def __init__(self, kernel_path, gemm_operation, short_path=False): - self.short_path = short_path - self.kernel_path = kernel_path - self.operation = gemm_operation - - instance_emitters = { - GemmKind.Gemm: EmitGemmInstance(), - GemmKind.SplitKParallel: EmitGemmSplitKParallelInstance(), - } - self.instance_emitter = instance_emitters[self.operation.gemm_kind] - - self.header_template = """ + def __init__(self, kernel_path, gemm_operation, short_path=False): + self.short_path = short_path + self.kernel_path = kernel_path + self.operation = gemm_operation + + instance_emitters = { + GemmKind.Gemm: EmitGemmInstance(), + GemmKind.SplitKParallel: EmitGemmSplitKParallelInstance(), + } + self.instance_emitter = instance_emitters[self.operation.gemm_kind] + + self.header_template = """ #if __CUDACC_VER_MAJOR__ > ${required_cuda_ver_major} || (__CUDACC_VER_MAJOR__ == ${required_cuda_ver_major} && __CUDACC_VER_MINOR__ >= ${required_cuda_ver_minor}) // ignore warning of cutlass #pragma GCC diagnostic push @@ -1049,11 +1314,11 @@ class EmitGemmSingleKernelWrapper: #include "src/cuda/cutlass/manifest.h" #include "src/cuda/cutlass/gemm_operation.h" """ - self.instance_template = """ + self.instance_template = """ ${operation_instance} """ - self.manifest_template = """ + 
self.manifest_template = """ namespace cutlass { namespace library { @@ -1067,53 +1332,69 @@ void initialize_${operation_name}(Manifest &manifest) { } // namespace cutlass """ - self.epilogue_template = """ + self.epilogue_template = """ #pragma GCC diagnostic pop #endif """ - # - def __enter__(self): - if self.short_path: - self.kernel_path = os.path.join(self.kernel_path, "%s.cu" % GlobalCnt.cnt) - GlobalCnt.cnt += 1 - else: - self.kernel_path = os.path.join(self.kernel_path, "%s.cu" % self.operation.procedural_name()) - self.kernel_file = open(self.kernel_path, "w") - self.kernel_file.write(SubstituteTemplate(self.header_template, { - 'required_cuda_ver_major': str(self.operation.required_cuda_ver_major), - 'required_cuda_ver_minor': str(self.operation.required_cuda_ver_minor), - })) - return self - - # - def emit(self): - self.kernel_file.write(SubstituteTemplate(self.instance_template, { - 'operation_instance': self.instance_emitter.emit(self.operation), - })) - - # emit manifest helper - manifest = SubstituteTemplate(self.manifest_template, { - 'operation_name': self.operation.procedural_name(), - }) - self.kernel_file.write(manifest) - - # - def __exit__(self, exception_type, exception_value, traceback): - self.kernel_file.write(self.epilogue_template) - self.kernel_file.close() + + # + def __enter__(self): + if self.short_path: + self.kernel_path = os.path.join(self.kernel_path, "%s.cu" % GlobalCnt.cnt) + GlobalCnt.cnt += 1 + else: + self.kernel_path = os.path.join( + self.kernel_path, "%s.cu" % self.operation.procedural_name() + ) + self.kernel_file = open(self.kernel_path, "w") + self.kernel_file.write( + SubstituteTemplate( + self.header_template, + { + "required_cuda_ver_major": str( + self.operation.required_cuda_ver_major + ), + "required_cuda_ver_minor": str( + self.operation.required_cuda_ver_minor + ), + }, + ) + ) + return self + + # + def emit(self): + self.kernel_file.write( + SubstituteTemplate( + self.instance_template, + 
{"operation_instance": self.instance_emitter.emit(self.operation)}, + ) + ) + + # emit manifest helper + manifest = SubstituteTemplate( + self.manifest_template, {"operation_name": self.operation.procedural_name()} + ) + self.kernel_file.write(manifest) + + # + def __exit__(self, exception_type, exception_value, traceback): + self.kernel_file.write(self.epilogue_template) + self.kernel_file.close() ################################################################################################### ################################################################################################### + class EmitGemvSingleKernelWrapper: - def __init__(self, kernel_path, gemm_operation, wrapper_path, short_path=False): - self.kernel_path = kernel_path - self.wrapper_path = wrapper_path - self.operation = gemm_operation - self.short_path = short_path + def __init__(self, kernel_path, gemm_operation, wrapper_path, short_path=False): + self.kernel_path = kernel_path + self.wrapper_path = wrapper_path + self.operation = gemm_operation + self.short_path = short_path - self.wrapper_template = """ + self.wrapper_template = """ template void megdnn::cuda::cutlass_wrapper:: cutlass_vector_matrix_mul_batched_strided_wrapper( BatchedGemmCoord const& problem_size, @@ -1123,9 +1404,9 @@ template void megdnn::cuda::cutlass_wrapper:: cudaStream_t stream); """ - self.instance_emitter = EmitGemvBatchedStridedInstance() + self.instance_emitter = EmitGemvBatchedStridedInstance() - self.header_template = """ + self.header_template = """ #if __CUDACC_VER_MAJOR__ > ${required_cuda_ver_major} || (__CUDACC_VER_MAJOR__ == ${required_cuda_ver_major} && __CUDACC_VER_MINOR__ >= ${required_cuda_ver_minor}) // ignore warning of cutlass #pragma GCC diagnostic push @@ -1135,45 +1416,61 @@ template void megdnn::cuda::cutlass_wrapper:: #pragma GCC diagnostic ignored "-Wmaybe-uninitialized" #include "${wrapper_path}" """ - self.instance_template = """ + self.instance_template = """ ${operation_instance} 
""" - self.epilogue_template = """ + self.epilogue_template = """ #pragma GCC diagnostic pop #endif """ - # - def __enter__(self): - if self.short_path: - self.kernel_path = os.path.join(self.kernel_path, "%s.cu" % GlobalCnt.cnt) - GlobalCnt.cnt += 1 - else: - self.kernel_path = os.path.join(self.kernel_path, "%s.cu" % self.operation.procedural_name()) - self.kernel_file = open(self.kernel_path, "w") - self.kernel_file.write(SubstituteTemplate(self.header_template, { - 'wrapper_path': self.wrapper_path, - 'required_cuda_ver_major': str(self.operation.required_cuda_ver_major), - 'required_cuda_ver_minor': str(self.operation.required_cuda_ver_minor), - })) - return self - - # - def emit(self): - self.kernel_file.write(SubstituteTemplate(self.instance_template, { - 'operation_instance': self.instance_emitter.emit(self.operation), - })) - - # emit wrapper - wrapper = SubstituteTemplate(self.wrapper_template, { - 'operation_name': self.operation.procedural_name(), - }) - self.kernel_file.write(wrapper) - - # - def __exit__(self, exception_type, exception_value, traceback): - self.kernel_file.write(self.epilogue_template) - self.kernel_file.close() + + # + def __enter__(self): + if self.short_path: + self.kernel_path = os.path.join(self.kernel_path, "%s.cu" % GlobalCnt.cnt) + GlobalCnt.cnt += 1 + else: + self.kernel_path = os.path.join( + self.kernel_path, "%s.cu" % self.operation.procedural_name() + ) + self.kernel_file = open(self.kernel_path, "w") + self.kernel_file.write( + SubstituteTemplate( + self.header_template, + { + "wrapper_path": self.wrapper_path, + "required_cuda_ver_major": str( + self.operation.required_cuda_ver_major + ), + "required_cuda_ver_minor": str( + self.operation.required_cuda_ver_minor + ), + }, + ) + ) + return self + + # + def emit(self): + self.kernel_file.write( + SubstituteTemplate( + self.instance_template, + {"operation_instance": self.instance_emitter.emit(self.operation)}, + ) + ) + + # emit wrapper + wrapper = SubstituteTemplate( + 
self.wrapper_template, {"operation_name": self.operation.procedural_name()} + ) + self.kernel_file.write(wrapper) + + # + def __exit__(self, exception_type, exception_value, traceback): + self.kernel_file.write(self.epilogue_template) + self.kernel_file.close() + ################################################################################################### ################################################################################################### diff --git a/dnn/scripts/cutlass_generator/generator.py b/dnn/scripts/cutlass_generator/generator.py index 91881e26..6b434316 100644 --- a/dnn/scripts/cutlass_generator/generator.py +++ b/dnn/scripts/cutlass_generator/generator.py @@ -12,62 +12,84 @@ import platform from library import * from manifest import * + ################################################################################################### # -def CudaToolkitVersionSatisfies(semantic_ver_string, major, minor, patch = 0): +def CudaToolkitVersionSatisfies(semantic_ver_string, major, minor, patch=0): - # by default, use the latest CUDA Toolkit version - cuda_version = [11, 0, 132] + # by default, use the latest CUDA Toolkit version + cuda_version = [11, 0, 132] - # Update cuda_version based on parsed string - if semantic_ver_string != '': - for i, x in enumerate([int(x) for x in semantic_ver_string.split('.')]): - if i < len(cuda_version): - cuda_version[i] = x - else: - cuda_version.append(x) - return cuda_version >= [major, minor, patch] + # Update cuda_version based on parsed string + if semantic_ver_string != "": + for i, x in enumerate([int(x) for x in semantic_ver_string.split(".")]): + if i < len(cuda_version): + cuda_version[i] = x + else: + cuda_version.append(x) + return cuda_version >= [major, minor, patch] ################################################################################################### ################################################################################################### # -def 
CreateGemmOperator(manifest, layouts, tile_descriptions, data_type, \ - alignment_constraints, complex_transforms = None, epilogue_functor = EpilogueFunctor.LinearCombination, \ - swizzling_functor = SwizzlingFunctor.Identity8): - - if complex_transforms is None: - complex_transforms = [(ComplexTransform.none, ComplexTransform.none),] +def CreateGemmOperator( + manifest, + layouts, + tile_descriptions, + data_type, + alignment_constraints, + complex_transforms=None, + epilogue_functor=EpilogueFunctor.LinearCombination, + swizzling_functor=SwizzlingFunctor.Identity8, +): - element_a, element_b, element_c, element_epilogue = data_type - - operations = [] + if complex_transforms is None: + complex_transforms = [(ComplexTransform.none, ComplexTransform.none)] - # by default, only generate the largest tile and largest alignment - if manifest.args.kernels == '': - tile_descriptions = [tile_descriptions[0],] - alignment_constraints = [alignment_constraints[0],] + element_a, element_b, element_c, element_epilogue = data_type - for layout in layouts: - for tile_description in tile_descriptions: - for alignment in alignment_constraints: - for complex_transform in complex_transforms: - - alignment_c = min(8, alignment) - - A = TensorDescription(element_a, layout[0], alignment, complex_transform[0]) - B = TensorDescription(element_b, layout[1], alignment, complex_transform[1]) - C = TensorDescription(element_c, layout[2], alignment_c) + operations = [] - new_operation = GemmOperation(GemmKind.Universal, tile_description.minimum_compute_capability, \ - tile_description, A, B, C, element_epilogue, epilogue_functor, swizzling_functor) + # by default, only generate the largest tile and largest alignment + if manifest.args.kernels == "": + tile_descriptions = [tile_descriptions[0]] + alignment_constraints = [alignment_constraints[0]] - manifest.append(new_operation) - operations.append(new_operation) + for layout in layouts: + for tile_description in tile_descriptions: + for 
alignment in alignment_constraints: + for complex_transform in complex_transforms: + + alignment_c = min(8, alignment) + + A = TensorDescription( + element_a, layout[0], alignment, complex_transform[0] + ) + B = TensorDescription( + element_b, layout[1], alignment, complex_transform[1] + ) + C = TensorDescription(element_c, layout[2], alignment_c) + + new_operation = GemmOperation( + GemmKind.Universal, + tile_description.minimum_compute_capability, + tile_description, + A, + B, + C, + element_epilogue, + epilogue_functor, + swizzling_functor, + ) + + manifest.append(new_operation) + operations.append(new_operation) + + return operations - return operations ########################################################################################################### # ConvolutionOperator support variations @@ -82,415 +104,735 @@ def CreateGemmOperator(manifest, layouts, tile_descriptions, data_type, \ # Note : Operator marked (*) are supported but not generated to keep the instantiated kernel count low ########################################################################################################### # Convolution for 2D operations -def CreateConv2dOperator(manifest, layout, tile_descriptions, data_type, alignment, \ - conv_kinds = [ConvKind.Fprop, ConvKind.Dgrad, ConvKind.Wgrad], epilogue_functor = EpilogueFunctor.LinearCombination): - - element_a, element_b, element_c, element_epilogue = data_type - - # one exceptional case - alignment_c = min(8, alignment) - - # iterator algorithm (analytic and optimized) - iterator_algorithms = [IteratorAlgorithm.Analytic, IteratorAlgorithm.Optimized] - - # by default, only generate the largest tile size - if manifest.args.kernels == '': - tile_descriptions = [tile_descriptions[0],] - - operations = [] - - for tile in tile_descriptions: - for conv_kind in conv_kinds: - for iterator_algorithm in iterator_algorithms: - A = TensorDescription(element_a, layout[0], alignment) - B = TensorDescription(element_b, layout[1], 
alignment) - C = TensorDescription(element_c, layout[2], alignment_c) - - # unity stride only for Optimized Dgrad - if (iterator_algorithm == IteratorAlgorithm.Optimized) and (conv_kind == ConvKind.Dgrad): - new_operation = Conv2dOperation(conv_kind, iterator_algorithm, tile.minimum_compute_capability, tile,\ - A, B, C, element_epilogue, StrideSupport.Unity, epilogue_functor) - - manifest.append(new_operation) - operations.append(new_operation) - - # strided dgrad is not supported by Optimized Dgrad - if (iterator_algorithm == IteratorAlgorithm.Optimized) and (conv_kind == ConvKind.Dgrad): - continue - - # strided support for Fprop (Analytic/Optimized), Dgrad (Analytic), and Wgrad (Analytic) - new_operation = Conv2dOperation(conv_kind, iterator_algorithm, tile.minimum_compute_capability, tile,\ - A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor) - - manifest.append(new_operation) - operations.append(new_operation) - - return operations +def CreateConv2dOperator( + manifest, + layout, + tile_descriptions, + data_type, + alignment, + conv_kinds=[ConvKind.Fprop, ConvKind.Dgrad, ConvKind.Wgrad], + epilogue_functor=EpilogueFunctor.LinearCombination, +): + + element_a, element_b, element_c, element_epilogue = data_type + + # one exceptional case + alignment_c = min(8, alignment) + + # iterator algorithm (analytic and optimized) + iterator_algorithms = [IteratorAlgorithm.Analytic, IteratorAlgorithm.Optimized] + + # by default, only generate the largest tile size + if manifest.args.kernels == "": + tile_descriptions = [tile_descriptions[0]] + + operations = [] + + for tile in tile_descriptions: + for conv_kind in conv_kinds: + for iterator_algorithm in iterator_algorithms: + A = TensorDescription(element_a, layout[0], alignment) + B = TensorDescription(element_b, layout[1], alignment) + C = TensorDescription(element_c, layout[2], alignment_c) + + # unity stride only for Optimized Dgrad + if (iterator_algorithm == IteratorAlgorithm.Optimized) and ( + 
conv_kind == ConvKind.Dgrad + ): + new_operation = Conv2dOperation( + conv_kind, + iterator_algorithm, + tile.minimum_compute_capability, + tile, + A, + B, + C, + element_epilogue, + StrideSupport.Unity, + epilogue_functor, + ) + + manifest.append(new_operation) + operations.append(new_operation) + + # strided dgrad is not supported by Optimized Dgrad + if (iterator_algorithm == IteratorAlgorithm.Optimized) and ( + conv_kind == ConvKind.Dgrad + ): + continue + + # strided support for Fprop (Analytic/Optimized), Dgrad (Analytic), and Wgrad (Analytic) + new_operation = Conv2dOperation( + conv_kind, + iterator_algorithm, + tile.minimum_compute_capability, + tile, + A, + B, + C, + element_epilogue, + StrideSupport.Strided, + epilogue_functor, + ) + + manifest.append(new_operation) + operations.append(new_operation) + + return operations + ################################################################################################### ################################################################################################### + def GenerateConv2d_Simt(args): - operations = [] - - layouts = [ - (LayoutType.TensorNC4HW4, LayoutType.TensorC4RSK4), - ] - - math_instructions = [ - MathInstruction( \ - [1, 1, 4], \ - DataType.s8, DataType.s8, DataType.s32, \ - OpcodeClass.Simt, \ - MathOperation.multiply_add), - ] - - dst_layouts = [ - LayoutType.TensorNC4HW4, - LayoutType.TensorNC32HW32, - LayoutType.TensorNHWC, - LayoutType.TensorNHWC, - LayoutType.TensorNCHW - ] - - dst_types = [ - DataType.s8, - DataType.s8, - DataType.u4, - DataType.s4, - DataType.f32, - ] - - max_cc = 1024 - - for math_inst in math_instructions: - for layout in layouts: - for dst_type, dst_layout in zip(dst_types, dst_layouts): - if dst_type == DataType.s4 or dst_type == DataType.u4: - min_cc = 75 - use_special_optimization = SpecialOptimizeDesc.NoneSpecialOpt - else: - min_cc = 61 - use_special_optimization = SpecialOptimizeDesc.ConvFilterUnity - tile_descriptions = [ - TileDescription([128, 
128, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([128, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 128, 32], 2, [1, 4, 1], math_inst, min_cc, max_cc), - TileDescription([128, 32, 32], 2, [2, 1, 1], math_inst, min_cc, max_cc), - TileDescription([ 32, 128, 32], 2, [1, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 32, 64, 32], 2, [1, 1, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 32, 32], 2, [1, 1, 1], math_inst, min_cc, max_cc), - TileDescription([ 16, 128, 16], 1, [1, 1, 1], math_inst, min_cc, max_cc), - TileDescription([ 16, 64, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc), - ] - for tile in tile_descriptions: - if dst_layout == LayoutType.TensorNC32HW32 and tile.threadblock_shape[0] > 32: - continue - if (dst_layout == LayoutType.TensorNCHW or dst_layout == LayoutType.TensorNHWC) \ - and tile.threadblock_shape[0] > 16: - continue - operations += GenerateConv2d(ConvKind.Fprop, [tile], layout[0], layout[1], - dst_layout, dst_type, min_cc, 32, 32, 32, - use_special_optimization) - return operations + operations = [] + + layouts = [(LayoutType.TensorNC4HW4, LayoutType.TensorC4RSK4)] + + math_instructions = [ + MathInstruction( + [1, 1, 4], + DataType.s8, + DataType.s8, + DataType.s32, + OpcodeClass.Simt, + MathOperation.multiply_add, + ) + ] + + dst_layouts = [ + LayoutType.TensorNC4HW4, + LayoutType.TensorNC32HW32, + LayoutType.TensorNHWC, + LayoutType.TensorNHWC, + LayoutType.TensorNCHW, + ] + + dst_types = [DataType.s8, DataType.s8, DataType.u4, DataType.s4, DataType.f32] + + max_cc = 1024 + + for math_inst in math_instructions: + for layout in layouts: + for dst_type, dst_layout in zip(dst_types, dst_layouts): + if dst_type == DataType.s4 or dst_type == DataType.u4: + min_cc = 75 + use_special_optimization = SpecialOptimizeDesc.NoneSpecialOpt + else: + min_cc = 61 + use_special_optimization = SpecialOptimizeDesc.ConvFilterUnity + tile_descriptions = [ + TileDescription( + [128, 128, 32], 2, 
[2, 4, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [128, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [64, 128, 32], 2, [1, 4, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [128, 32, 32], 2, [2, 1, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [32, 128, 32], 2, [1, 2, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [32, 64, 32], 2, [1, 1, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [64, 32, 32], 2, [1, 1, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [16, 128, 16], 1, [1, 1, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [16, 64, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc + ), + ] + for tile in tile_descriptions: + if ( + dst_layout == LayoutType.TensorNC32HW32 + and tile.threadblock_shape[0] > 32 + ): + continue + if ( + dst_layout == LayoutType.TensorNCHW + or dst_layout == LayoutType.TensorNHWC + ) and tile.threadblock_shape[0] > 16: + continue + operations += GenerateConv2d( + ConvType.Convolution, + ConvKind.Fprop, + [tile], + layout[0], + layout[1], + dst_layout, + dst_type, + min_cc, + 32, + 32, + 32, + use_special_optimization, + ) + return operations def GenerateConv2d_TensorOp_8816(args): - operations = [] - - layouts = [ - (LayoutType.TensorNC32HW32, LayoutType.TensorC32RSK32), - ] - - math_instructions = [ - MathInstruction( \ - [8, 8, 16], \ - DataType.s8, DataType.s8, DataType.s32, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add_saturate), - ] - - dst_layouts = [ - LayoutType.TensorNC32HW32, - LayoutType.TensorNC4HW4, - ] - - dst_types = [ - DataType.s8, - DataType.s8, - ] - - use_special_optimization = SpecialOptimizeDesc.ConvFilterUnity - - min_cc = 75 - max_cc = 1024 - - cuda_major = 10 - cuda_minor = 2 - - for math_inst in math_instructions: - for layout in layouts: - for dst_type, dst_layout in zip(dst_types, dst_layouts): - if dst_layout == LayoutType.TensorNC32HW32: - tile_descriptions = [ - TileDescription([128, 256, 64], 2, [2, 4, 
1], math_inst, min_cc, max_cc), - TileDescription([256, 128, 64], 2, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 64, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 64, 32], 1, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 32, 32], 1, [2, 1, 1], math_inst, min_cc, max_cc), - ] - operations += GenerateConv2d(ConvKind.Fprop, tile_descriptions, layout[0], layout[1], - dst_layout, dst_type, min_cc, 128, 128, 64, use_special_optimization, - ImplicitGemmMode.GemmTN, True, cuda_major, cuda_minor) - else: - assert dst_layout == LayoutType.TensorNC4HW4 - tile_descriptions = [ - TileDescription([ 64, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 32, 128, 32], 1, [1, 2, 1], math_inst, min_cc, max_cc), - ] - operations += GenerateConv2d(ConvKind.Fprop, tile_descriptions, layout[0], layout[1], - dst_layout, dst_type, min_cc, 128, 128, 64, use_special_optimization, - ImplicitGemmMode.GemmNT, False, cuda_major, cuda_minor) - - layouts_nhwc = [ - (LayoutType.TensorNHWC, LayoutType.TensorNC4HW4, 32), - (LayoutType.TensorNHWC, LayoutType.TensorNC8HW8, 64), - (LayoutType.TensorNHWC, LayoutType.TensorNC16HW16, 128), - ] - - dst_layouts_nhwc = [ - LayoutType.TensorNHWC, - ] - - for math_inst in math_instructions: - for layout in layouts_nhwc: - for dst_layout in dst_layouts_nhwc: - dst_type = math_inst.element_b - tile_descriptions = [ - TileDescription([128, 32, 32], 1, [2, 1, 1], math_inst, min_cc, max_cc), - TileDescription([64, 16, 32], 2, [1, 1, 1], math_inst, min_cc, max_cc), - ] - for tile in tile_descriptions: - dst_align = 32 if tile.threadblock_shape[1] == 16 else 64 - operations += GenerateConv2d(ConvKind.Fprop, [tile], layout[0], layout[1], dst_layout, - dst_type, min_cc, layout[2], layout[2], dst_align, use_special_optimization, - 
ImplicitGemmMode.GemmTN, False, cuda_major, cuda_minor) - if tile.threadblock_shape[1] == 16 or tile.threadblock_shape[1] == 32: - operations += GenerateConv2d(ConvKind.Fprop, [tile], layout[0], layout[1], dst_layout, - dst_type, min_cc, layout[2], layout[2], dst_align, use_special_optimization, - ImplicitGemmMode.GemmTN, True, cuda_major, cuda_minor) - - out_dtypes = [DataType.s4, DataType.u4, DataType.f32] - - #INT8x8x4 and INT8x8x32 - for math_inst in math_instructions: - for layout in layouts_nhwc: - for dst_layout in dst_layouts_nhwc: - for out_dtype in out_dtypes: - tile_descriptions = [ - TileDescription([128, 32, 32], 1, [2, 1, 1], math_inst, min_cc, max_cc), - TileDescription([64, 16, 32], 2, [1, 1, 1], math_inst, min_cc, max_cc), - ] - for tile in tile_descriptions: - dst_align = 4 * DataTypeSize[out_dtype] if tile.threadblock_shape[1] == 16 or out_dtype == DataType.f32 \ - else 8 * DataTypeSize[out_dtype] - operations += GenerateConv2d(ConvKind.Fprop, [tile], layout[0], layout[1], dst_layout, - out_dtype, min_cc, layout[2], layout[2], dst_align, use_special_optimization, - ImplicitGemmMode.GemmTN, False, cuda_major, cuda_minor) - if tile.threadblock_shape[1] == 16 or (tile.threadblock_shape[1] == 32 and out_dtype != DataType.f32): - operations += GenerateConv2d(ConvKind.Fprop, [tile], layout[0], layout[1], dst_layout, - out_dtype, min_cc, layout[2], layout[2], dst_align, use_special_optimization, - ImplicitGemmMode.GemmTN, True, cuda_major, cuda_minor) - - return operations + operations = [] + + layouts = [(LayoutType.TensorNC32HW32, LayoutType.TensorC32RSK32)] + + math_instructions = [ + MathInstruction( + [8, 8, 16], + DataType.s8, + DataType.s8, + DataType.s32, + OpcodeClass.TensorOp, + MathOperation.multiply_add_saturate, + ) + ] + + dst_layouts = [LayoutType.TensorNC32HW32, LayoutType.TensorNC4HW4] + + dst_types = [DataType.s8, DataType.s8] + + use_special_optimization = SpecialOptimizeDesc.ConvFilterUnity + + min_cc = 75 + max_cc = 1024 + + 
cuda_major = 10 + cuda_minor = 2 + + for math_inst in math_instructions: + for layout in layouts: + for dst_type, dst_layout in zip(dst_types, dst_layouts): + if dst_layout == LayoutType.TensorNC32HW32: + tile_descriptions = [ + TileDescription( + [128, 256, 64], 2, [2, 4, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [256, 128, 64], 2, [4, 2, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [128, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [128, 64, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [64, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [128, 64, 32], 1, [2, 2, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [128, 32, 32], 1, [2, 1, 1], math_inst, min_cc, max_cc + ), + ] + operations += GenerateConv2d( + ConvType.Convolution, + ConvKind.Fprop, + tile_descriptions, + layout[0], + layout[1], + dst_layout, + dst_type, + min_cc, + 128, + 128, + 64, + use_special_optimization, + ImplicitGemmMode.GemmTN, + True, + cuda_major, + cuda_minor, + ) + else: + assert dst_layout == LayoutType.TensorNC4HW4 + tile_descriptions = [ + TileDescription( + [64, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [32, 128, 32], 1, [1, 2, 1], math_inst, min_cc, max_cc + ), + ] + operations += GenerateConv2d( + ConvType.Convolution, + ConvKind.Fprop, + tile_descriptions, + layout[0], + layout[1], + dst_layout, + dst_type, + min_cc, + 128, + 128, + 64, + use_special_optimization, + ImplicitGemmMode.GemmNT, + False, + cuda_major, + cuda_minor, + ) + + layouts_nhwc = [ + (LayoutType.TensorNHWC, LayoutType.TensorNC4HW4, 32), + (LayoutType.TensorNHWC, LayoutType.TensorNC8HW8, 64), + (LayoutType.TensorNHWC, LayoutType.TensorNC16HW16, 128), + ] + + dst_layouts_nhwc = [LayoutType.TensorNHWC] + + for math_inst in math_instructions: + for layout in layouts_nhwc: + for dst_layout in dst_layouts_nhwc: + dst_type = math_inst.element_b + tile_descriptions = [ 
+ TileDescription( + [128, 32, 32], 1, [2, 1, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [64, 16, 32], 2, [1, 1, 1], math_inst, min_cc, max_cc + ), + ] + for tile in tile_descriptions: + dst_align = 32 if tile.threadblock_shape[1] == 16 else 64 + operations += GenerateConv2d( + ConvType.Convolution, + ConvKind.Fprop, + [tile], + layout[0], + layout[1], + dst_layout, + dst_type, + min_cc, + layout[2], + layout[2], + dst_align, + use_special_optimization, + ImplicitGemmMode.GemmTN, + False, + cuda_major, + cuda_minor, + ) + if ( + tile.threadblock_shape[1] == 16 + or tile.threadblock_shape[1] == 32 + ): + operations += GenerateConv2d( + ConvType.Convolution, + ConvKind.Fprop, + [tile], + layout[0], + layout[1], + dst_layout, + dst_type, + min_cc, + layout[2], + layout[2], + dst_align, + use_special_optimization, + ImplicitGemmMode.GemmTN, + True, + cuda_major, + cuda_minor, + ) + + out_dtypes = [DataType.s4, DataType.u4, DataType.f32] + + # INT8x8x4 and INT8x8x32 + for math_inst in math_instructions: + for layout in layouts_nhwc: + for dst_layout in dst_layouts_nhwc: + for out_dtype in out_dtypes: + tile_descriptions = [ + TileDescription( + [128, 32, 32], 1, [2, 1, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [64, 16, 32], 2, [1, 1, 1], math_inst, min_cc, max_cc + ), + ] + for tile in tile_descriptions: + dst_align = ( + 4 * DataTypeSize[out_dtype] + if tile.threadblock_shape[1] == 16 + or out_dtype == DataType.f32 + else 8 * DataTypeSize[out_dtype] + ) + operations += GenerateConv2d( + ConvType.Convolution, + ConvKind.Fprop, + [tile], + layout[0], + layout[1], + dst_layout, + out_dtype, + min_cc, + layout[2], + layout[2], + dst_align, + use_special_optimization, + ImplicitGemmMode.GemmTN, + False, + cuda_major, + cuda_minor, + ) + if tile.threadblock_shape[1] == 16 or ( + tile.threadblock_shape[1] == 32 + and out_dtype != DataType.f32 + ): + operations += GenerateConv2d( + ConvType.Convolution, + ConvKind.Fprop, + [tile], + layout[0], + 
layout[1], + dst_layout, + out_dtype, + min_cc, + layout[2], + layout[2], + dst_align, + use_special_optimization, + ImplicitGemmMode.GemmTN, + True, + cuda_major, + cuda_minor, + ) + + return operations + def GenerateConv2d_TensorOp_8832(args): - operations = [] - - layouts = [ - (LayoutType.TensorNC64HW64, LayoutType.TensorC64RSK64), - ] - - math_instructions = [ - MathInstruction( \ - [8, 8, 32], \ - DataType.s4, DataType.s4, DataType.s32, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add_saturate), \ - MathInstruction( \ - [8, 8, 32], \ - DataType.s4, DataType.u4, DataType.s32, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add_saturate) - ] - - dst_layouts = [ - LayoutType.TensorNC64HW64, - ] - - use_special_optimization = SpecialOptimizeDesc.ConvFilterUnity - - min_cc = 75 - max_cc = 1024 - - cuda_major = 10 - cuda_minor = 2 - - for math_inst in math_instructions: - for layout in layouts: - for dst_layout in dst_layouts: - dst_type = math_inst.element_b - tile_descriptions = [ - TileDescription([128, 256, 128], 2, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 128], 2, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 64, 128], 2, [2, 1, 1], math_inst, min_cc, max_cc), - TileDescription([128, 64, 64], 1, [2, 1, 1], math_inst, min_cc, max_cc), - ] - operations += GenerateConv2d(ConvKind.Fprop, tile_descriptions, layout[0], layout[1], - dst_layout, dst_type, min_cc, 128, 128, 64, use_special_optimization, - ImplicitGemmMode.GemmTN, True, cuda_major, cuda_minor) - - layouts_nhwc = [ - (LayoutType.TensorNHWC, LayoutType.TensorNC8HW8, 32), - (LayoutType.TensorNHWC, LayoutType.TensorNC16HW16, 64), - (LayoutType.TensorNHWC, LayoutType.TensorNC32HW32, 128), - ] - - dst_layouts_nhwc = [ - LayoutType.TensorNHWC, - ] - - for math_inst in math_instructions: - for layout in layouts_nhwc: - for dst_layout in dst_layouts_nhwc: - dst_type = math_inst.element_b - tile_descriptions = [ - TileDescription([128, 16, 64], 2, [1, 1, 
1], math_inst, min_cc, max_cc), - TileDescription([128, 32, 64], 1, [2, 1, 1], math_inst, min_cc, max_cc), - TileDescription([128, 64, 64], 1, [2, 1, 1], math_inst, min_cc, max_cc), - ] - for tile in tile_descriptions: - dst_align = 16 if tile.threadblock_shape[1] == 16 else 32 - operations += GenerateConv2d(ConvKind.Fprop, [tile], layout[0], layout[1], dst_layout, - dst_type, min_cc, layout[2], layout[2], dst_align, use_special_optimization, - ImplicitGemmMode.GemmTN, False, cuda_major, cuda_minor) - if tile.threadblock_shape[1] == 32 or tile.threadblock_shape[1] == 64: - dst_align = 32 if tile.threadblock_shape[1] == 32 else 64 - operations += GenerateConv2d(ConvKind.Fprop, [tile], layout[0], layout[1], dst_layout, - dst_type, min_cc, layout[2], layout[2], dst_align, use_special_optimization, - ImplicitGemmMode.GemmTN, True, cuda_major, cuda_minor) - # INT4x4x8 - for math_inst in math_instructions: - for layout in layouts_nhwc: - for dst_layout in dst_layouts_nhwc: - tile_descriptions = [ - TileDescription([128, 16, 64], 2, [1, 1, 1], math_inst, min_cc, max_cc), - TileDescription([128, 32, 64], 1, [2, 1, 1], math_inst, min_cc, max_cc), - TileDescription([128, 64, 64], 1, [2, 1, 1], math_inst, min_cc, max_cc), - ] - for tile in tile_descriptions: - dst_align = 32 if tile.threadblock_shape[1] == 16 else 64 - operations += GenerateConv2d(ConvKind.Fprop, [tile], layout[0], layout[1], dst_layout, - DataType.s8, min_cc, layout[2], layout[2], dst_align, use_special_optimization, - ImplicitGemmMode.GemmTN, False, cuda_major, cuda_minor) - if tile.threadblock_shape[1] == 32 or tile.threadblock_shape[1] == 64: - dst_align = 64 if tile.threadblock_shape[1] == 32 else 128 - operations += GenerateConv2d(ConvKind.Fprop, [tile], layout[0], layout[1], dst_layout, - DataType.s8, min_cc, layout[2], layout[2], dst_align, use_special_optimization, - ImplicitGemmMode.GemmTN, True, cuda_major, cuda_minor) - - return operations + operations = [] + + layouts = 
[(LayoutType.TensorNC64HW64, LayoutType.TensorC64RSK64)] + + math_instructions = [ + MathInstruction( + [8, 8, 32], + DataType.s4, + DataType.s4, + DataType.s32, + OpcodeClass.TensorOp, + MathOperation.multiply_add_saturate, + ), + MathInstruction( + [8, 8, 32], + DataType.s4, + DataType.u4, + DataType.s32, + OpcodeClass.TensorOp, + MathOperation.multiply_add_saturate, + ), + ] + + dst_layouts = [LayoutType.TensorNC64HW64] + + use_special_optimization = SpecialOptimizeDesc.ConvFilterUnity + + min_cc = 75 + max_cc = 1024 + + cuda_major = 10 + cuda_minor = 2 + + for math_inst in math_instructions: + for layout in layouts: + for dst_layout in dst_layouts: + dst_type = math_inst.element_b + tile_descriptions = [ + TileDescription( + [128, 256, 128], 2, [2, 4, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [128, 128, 128], 2, [2, 2, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [128, 64, 128], 2, [2, 1, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [128, 64, 64], 1, [2, 1, 1], math_inst, min_cc, max_cc + ), + ] + operations += GenerateConv2d( + ConvType.Convolution, + ConvKind.Fprop, + tile_descriptions, + layout[0], + layout[1], + dst_layout, + dst_type, + min_cc, + 128, + 128, + 64, + use_special_optimization, + ImplicitGemmMode.GemmTN, + True, + cuda_major, + cuda_minor, + ) + + layouts_nhwc = [ + (LayoutType.TensorNHWC, LayoutType.TensorNC8HW8, 32), + (LayoutType.TensorNHWC, LayoutType.TensorNC16HW16, 64), + (LayoutType.TensorNHWC, LayoutType.TensorNC32HW32, 128), + ] + + dst_layouts_nhwc = [LayoutType.TensorNHWC] + + for math_inst in math_instructions: + for layout in layouts_nhwc: + for dst_layout in dst_layouts_nhwc: + dst_type = math_inst.element_b + tile_descriptions = [ + TileDescription( + [128, 16, 64], 2, [1, 1, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [128, 32, 64], 1, [2, 1, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [128, 64, 64], 1, [2, 1, 1], math_inst, min_cc, max_cc + ), + ] + for tile in 
tile_descriptions: + dst_align = 16 if tile.threadblock_shape[1] == 16 else 32 + operations += GenerateConv2d( + ConvType.Convolution, + ConvKind.Fprop, + [tile], + layout[0], + layout[1], + dst_layout, + dst_type, + min_cc, + layout[2], + layout[2], + dst_align, + use_special_optimization, + ImplicitGemmMode.GemmTN, + False, + cuda_major, + cuda_minor, + ) + if ( + tile.threadblock_shape[1] == 32 + or tile.threadblock_shape[1] == 64 + ): + dst_align = 32 if tile.threadblock_shape[1] == 32 else 64 + operations += GenerateConv2d( + ConvType.Convolution, + ConvKind.Fprop, + [tile], + layout[0], + layout[1], + dst_layout, + dst_type, + min_cc, + layout[2], + layout[2], + dst_align, + use_special_optimization, + ImplicitGemmMode.GemmTN, + True, + cuda_major, + cuda_minor, + ) + # INT4x4x8 + for math_inst in math_instructions: + for layout in layouts_nhwc: + for dst_layout in dst_layouts_nhwc: + tile_descriptions = [ + TileDescription( + [128, 16, 64], 2, [1, 1, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [128, 32, 64], 1, [2, 1, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [128, 64, 64], 1, [2, 1, 1], math_inst, min_cc, max_cc + ), + ] + for tile in tile_descriptions: + dst_align = 32 if tile.threadblock_shape[1] == 16 else 64 + operations += GenerateConv2d( + ConvType.Convolution, + ConvKind.Fprop, + [tile], + layout[0], + layout[1], + dst_layout, + DataType.s8, + min_cc, + layout[2], + layout[2], + dst_align, + use_special_optimization, + ImplicitGemmMode.GemmTN, + False, + cuda_major, + cuda_minor, + ) + if ( + tile.threadblock_shape[1] == 32 + or tile.threadblock_shape[1] == 64 + ): + dst_align = 64 if tile.threadblock_shape[1] == 32 else 128 + operations += GenerateConv2d( + ConvType.Convolution, + ConvKind.Fprop, + [tile], + layout[0], + layout[1], + dst_layout, + DataType.s8, + min_cc, + layout[2], + layout[2], + dst_align, + use_special_optimization, + ImplicitGemmMode.GemmTN, + True, + cuda_major, + cuda_minor, + ) + + return operations 
+ def GenerateDeconv_Simt(args): - operations = [] - - layouts = [ - (LayoutType.TensorNC4HW4, LayoutType.TensorK4RSC4), - ] - - math_instructions = [ - MathInstruction( \ - [1, 1, 4], \ - DataType.s8, DataType.s8, DataType.s32, \ - OpcodeClass.Simt, \ - MathOperation.multiply_add), - ] - - dst_layouts = [ - LayoutType.TensorNC4HW4, - ] - - dst_types = [ - DataType.s8, - ] - - use_special_optimization = SpecialOptimizeDesc.DeconvDoubleUpsampling - - min_cc = 61 - max_cc = 1024 - - for math_inst in math_instructions: - for layout in layouts: - for dst_type, dst_layout in zip(dst_types, dst_layouts): - tile_descriptions = [ - TileDescription([32, 128, 32], 2, [1, 2, 1], math_inst, min_cc, max_cc), - TileDescription([16, 128, 16], 2, [1, 2, 1], math_inst, min_cc, max_cc), - TileDescription([16, 128, 16], 1, [1, 1, 1], math_inst, min_cc, max_cc), - TileDescription([16, 64, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc), - ] - operations += GenerateConv2d(ConvKind.Dgrad, tile_descriptions, layout[0], layout[1], - dst_layout, dst_type, min_cc, 32, 32, 32, - use_special_optimization) - return operations + operations = [] + + layouts = [(LayoutType.TensorNC4HW4, LayoutType.TensorK4RSC4)] + + math_instructions = [ + MathInstruction( + [1, 1, 4], + DataType.s8, + DataType.s8, + DataType.s32, + OpcodeClass.Simt, + MathOperation.multiply_add, + ) + ] + + dst_layouts = [LayoutType.TensorNC4HW4] + + dst_types = [DataType.s8] + + use_special_optimization = SpecialOptimizeDesc.DeconvDoubleUpsampling + + min_cc = 61 + max_cc = 1024 + + for math_inst in math_instructions: + for layout in layouts: + for dst_type, dst_layout in zip(dst_types, dst_layouts): + tile_descriptions = [ + TileDescription( + [32, 128, 32], 2, [1, 2, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [16, 128, 16], 2, [1, 2, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [16, 128, 16], 1, [1, 1, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [16, 64, 8], 2, [1, 1, 1], math_inst, min_cc, 
max_cc + ), + ] + operations += GenerateConv2d( + ConvType.Convolution, + ConvKind.Dgrad, + tile_descriptions, + layout[0], + layout[1], + dst_layout, + dst_type, + min_cc, + 32, + 32, + 32, + use_special_optimization, + ) + return operations + def GenerateDeconv_TensorOp_8816(args): - operations = [] - - layouts = [ - (LayoutType.TensorNHWC, LayoutType.TensorCK4RS4, 32), - (LayoutType.TensorNHWC, LayoutType.TensorCK8RS8, 64), - (LayoutType.TensorNHWC, LayoutType.TensorCK16RS16, 128), - ] - - math_instructions = [ - MathInstruction( \ - [8, 8, 16], \ - DataType.s8, DataType.s8, DataType.s32, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add_saturate), - ] - - dst_layouts = [ - LayoutType.TensorNHWC, - ] - - dst_types = [ - DataType.s8, - ] - - use_special_optimization = SpecialOptimizeDesc.DeconvDoubleUpsampling - - min_cc = 75 - max_cc = 1024 - - cuda_major = 10 - cuda_minor = 2 - - for math_inst in math_instructions: - for layout in layouts: - for dst_type, dst_layout in zip(dst_types, dst_layouts): - tile_descriptions = [ - TileDescription([128, 32, 32], 1, [2, 1, 1], math_inst, min_cc, max_cc), - TileDescription([64, 16, 32], 2, [1, 1, 1], math_inst, min_cc, max_cc), - ] - for tile in tile_descriptions: - dst_align = 32 if tile.threadblock_shape[1] == 16 else 64 - operations += GenerateConv2d(ConvKind.Dgrad, [tile], layout[0], layout[1], dst_layout, dst_type, - min_cc, layout[2], layout[2], dst_align, use_special_optimization, - ImplicitGemmMode.GemmTN, False, cuda_major, cuda_minor) - return operations + operations = [] + + layouts = [ + (LayoutType.TensorNHWC, LayoutType.TensorCK4RS4, 32), + (LayoutType.TensorNHWC, LayoutType.TensorCK8RS8, 64), + (LayoutType.TensorNHWC, LayoutType.TensorCK16RS16, 128), + ] + + math_instructions = [ + MathInstruction( + [8, 8, 16], + DataType.s8, + DataType.s8, + DataType.s32, + OpcodeClass.TensorOp, + MathOperation.multiply_add_saturate, + ) + ] + + dst_layouts = [LayoutType.TensorNHWC] + + dst_types = [DataType.s8] + 
+ use_special_optimization = SpecialOptimizeDesc.DeconvDoubleUpsampling + + min_cc = 75 + max_cc = 1024 + + cuda_major = 10 + cuda_minor = 2 + + for math_inst in math_instructions: + for layout in layouts: + for dst_type, dst_layout in zip(dst_types, dst_layouts): + tile_descriptions = [ + TileDescription( + [128, 32, 32], 1, [2, 1, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [64, 16, 32], 2, [1, 1, 1], math_inst, min_cc, max_cc + ), + ] + for tile in tile_descriptions: + dst_align = 32 if tile.threadblock_shape[1] == 16 else 64 + operations += GenerateConv2d( + ConvType.Convolution, + ConvKind.Dgrad, + [tile], + layout[0], + layout[1], + dst_layout, + dst_type, + min_cc, + layout[2], + layout[2], + dst_align, + use_special_optimization, + ImplicitGemmMode.GemmTN, + False, + cuda_major, + cuda_minor, + ) + return operations + ################################################################################ # parameters @@ -507,429 +849,843 @@ warpsPerThreadblockMax = 16 warpShapeEdges = [8, 16, 32, 64, 128, 256] warpShapeRatio = 4 -warpShapeMax = 64*64 -warpShapeMin = 8*8 +warpShapeMax = 64 * 64 +warpShapeMin = 8 * 8 threadblockEdgeMax = 256 # char, type bits/elem, max tile, L0 threadblock tiles precisions = { - "c" : [ "cutlass::complex", 64, 64*128, [ [ 64, 128], [ 64, 32] ] ], - "d" : [ "double", 64, 64*64, [ [ 64, 64], [ 32, 32] ] ], - "h" : [ "cutlass::half_t", 16, 128*256, [ [256, 128], [ 64, 128], [ 64, 32] ] ], - "i" : [ "int", 32, 128*128, [ [128, 64], [ 16, 32] ] ], - "s" : [ "float", 32, 128*128, [ [128, 256], [128, 128], [ 64, 64] ] ], - "z" : [ "cutlass::complex", 128, 64*64, [ [ 32, 64], [ 16, 32] ] ], + "c": ["cutlass::complex", 64, 64 * 128, [[64, 128], [64, 32]]], + "d": ["double", 64, 64 * 64, [[64, 64], [32, 32]]], + "h": ["cutlass::half_t", 16, 128 * 256, [[256, 128], [64, 128], [64, 32]]], + "i": ["int", 32, 128 * 128, [[128, 64], [16, 32]]], + "s": ["float", 32, 128 * 128, [[128, 256], [128, 128], [64, 64]]], + "z": 
["cutlass::complex", 128, 64 * 64, [[32, 64], [16, 32]]], } # L1 will have a single kernel for every unique shape # L2 will have everything else def GenerateGemm_Simt(args): - ################################################################################ - # warps per threadblock - ################################################################################ - warpsPerThreadblocks = [] - for warpsPerThreadblock0 in warpsPerThreadblockEdge: - for warpsPerThreadblock1 in warpsPerThreadblockEdge: - if warpsPerThreadblock0 / warpsPerThreadblock1 <= warpsPerThreadblockRatio \ - and warpsPerThreadblock1 / warpsPerThreadblock0 <= warpsPerThreadblockRatio \ - and warpsPerThreadblock0 * warpsPerThreadblock1 <= warpsPerThreadblockMax: - warpsPerThreadblocks.append([warpsPerThreadblock0, - warpsPerThreadblock1]) - - ################################################################################ - # warp shapes - ################################################################################ - warpNumThreads = 32 - warpShapes = [] - for warp0 in warpShapeEdges: - for warp1 in warpShapeEdges: - if warp0 / warp1 <= warpShapeRatio \ - and warp1 / warp0 <= warpShapeRatio \ - and warp0 * warp1 <= warpShapeMax \ - and warp0*warp1 > warpShapeMin: - warpShapes.append([warp0, warp1]) - - # sgemm - precisionType, precisionBits, threadblockMaxElements, threadblockTilesL0 = precisions["s"] - - layouts = [ - (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.RowMajor), # nn - (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.RowMajor), # nt - (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor), # tn - (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.RowMajor), # tt - ] - - math_instructions = [ - MathInstruction( \ - [1, 1, 1], \ - DataType.f32, DataType.f32, DataType.f32, \ - OpcodeClass.Simt, \ - MathOperation.multiply_add), - ] - - min_cc = 50 - max_cc = 1024 - - operations = [] - for math_inst in math_instructions: - for layout in layouts: - 
data_type = [ - math_inst.element_a, - math_inst.element_b, - math_inst.element_accumulator, - math_inst.element_accumulator, - ] - tile_descriptions = [ - TileDescription([64, 256, 8], 2, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([256, 64, 8], 2, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 32, 256, 8], 2, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([256, 32, 8], 2, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 8], 2, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 64, 8], 2, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 128, 8], 2, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 32, 8], 2, [2, 1, 1], math_inst, min_cc, max_cc), - TileDescription([ 32, 128, 8], 2, [1, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 64, 8], 2, [2, 1, 1], math_inst, min_cc, max_cc), - TileDescription([ 32, 64, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 32, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc), - TileDescription([ 32, 32, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc), - TileDescription([ 8, 32, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc), - TileDescription([ 16, 32, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc), - TileDescription([ 16, 64, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc), - TileDescription([ 16, 128, 8], 2, [1, 2, 1], math_inst, min_cc, max_cc), - ] - for warpsPerThreadblock in warpsPerThreadblocks: - for warpShape in warpShapes: - warpThreadsM = 0 - if warpShape[0] > warpShape[1]: - warpThreadsM = 8 - else: - warpThreadsM = 4 - warpThreadsN = warpNumThreads / warpThreadsM - - # skip shapes with conflicting rectangularity - # they are unlikely to be fastest - blockG = warpsPerThreadblock[0] > warpsPerThreadblock[1] - blockL = warpsPerThreadblock[0] < warpsPerThreadblock[1] - warpG = warpShape[0] > warpShape[1] - warpL = warpShape[0] < warpShape[1] - - blockG2 = warpsPerThreadblock[0] > warpsPerThreadblock[1]*2 - 
blockL2 = warpsPerThreadblock[0]*2 < warpsPerThreadblock[1] - warpG2 = warpShape[0] > warpShape[1]*2 - warpL2 = warpShape[0]*2 < warpShape[1] - - if blockG2 and warpL: continue - if blockL2 and warpG: continue - if warpG2 and blockL: continue - if warpL2 and blockG: continue - - # check threadblock ratios and max - threadblockTile = [warpShape[0]*warpsPerThreadblock[0], - warpShape[1]*warpsPerThreadblock[1]] - if threadblockTile[0] * threadblockTile[1] > threadblockMaxElements: continue - if threadblockTile[0] > threadblockEdgeMax: continue - if threadblockTile[1] > threadblockEdgeMax: continue - totalThreads = warpNumThreads*warpsPerThreadblock[0]*warpsPerThreadblock[1] - - # calculate unroll - # ensure that every iteration at least a full load of A,B are done - unrollMin = 8 - unrollMin0 = totalThreads // threadblockTile[0] - unrollMin1 = totalThreads // threadblockTile[1] - unroll = max(unrollMin, unrollMin0, unrollMin1) - - threadTileM = warpShape[0] // warpThreadsM - threadTileN = warpShape[1] // warpThreadsN - if threadTileM < 2 or threadTileN < 2: continue - if threadTileM*threadTileN*precisionBits > 8*8*32: continue - - # epilogue currently only supports N < WarpNumThreads - if threadblockTile[1] < warpNumThreads: continue - - # limit smem - smemBitsA = threadblockTile[0]*unroll*2*precisionBits - smemBitsB = threadblockTile[1]*unroll*2*precisionBits - smemKBytes = (smemBitsA+smemBitsB)/8/1024 - if (smemKBytes > 48): continue - - tile = TileDescription([threadblockTile[0], threadblockTile[1], unroll], \ - 2, \ - [threadblockTile[0]//warpShape[0], threadblockTile[1]//warpShape[1], 1], \ - math_inst, min_cc, max_cc) - - def filter(t: TileDescription) -> bool: - nonlocal tile - return t.threadblock_shape[0] == tile.threadblock_shape[0] and \ - t.threadblock_shape[1] == tile.threadblock_shape[1] and \ - t.threadblock_shape[2] == tile.threadblock_shape[2] and \ - t.warp_count[0] == tile.warp_count[0] and \ - t.warp_count[1] == tile.warp_count[1] and \ - 
t.warp_count[2] == tile.warp_count[2] and \ - t.stages == tile.stages - if not any(t for t in tile_descriptions if filter(t)): continue - - operations += GeneratesGemm(tile, data_type, layout[0], layout[1], layout[2], min_cc) - return operations + ################################################################################ + # warps per threadblock + ################################################################################ + warpsPerThreadblocks = [] + for warpsPerThreadblock0 in warpsPerThreadblockEdge: + for warpsPerThreadblock1 in warpsPerThreadblockEdge: + if ( + warpsPerThreadblock0 / warpsPerThreadblock1 <= warpsPerThreadblockRatio + and warpsPerThreadblock1 / warpsPerThreadblock0 + <= warpsPerThreadblockRatio + and warpsPerThreadblock0 * warpsPerThreadblock1 + <= warpsPerThreadblockMax + ): + warpsPerThreadblocks.append( + [warpsPerThreadblock0, warpsPerThreadblock1] + ) + + ################################################################################ + # warp shapes + ################################################################################ + warpNumThreads = 32 + warpShapes = [] + for warp0 in warpShapeEdges: + for warp1 in warpShapeEdges: + if ( + warp0 / warp1 <= warpShapeRatio + and warp1 / warp0 <= warpShapeRatio + and warp0 * warp1 <= warpShapeMax + and warp0 * warp1 > warpShapeMin + ): + warpShapes.append([warp0, warp1]) + + # sgemm + precisionType, precisionBits, threadblockMaxElements, threadblockTilesL0 = precisions[ + "s" + ] + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.RowMajor), # nn + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.RowMajor), # nt + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor), # tn + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.RowMajor), # tt + ] + + math_instructions = [ + MathInstruction( + [1, 1, 1], + DataType.f32, + DataType.f32, + DataType.f32, + OpcodeClass.Simt, + MathOperation.multiply_add, + ) + ] + + min_cc = 50 + max_cc 
= 1024 + + operations = [] + for math_inst in math_instructions: + for layout in layouts: + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + tile_descriptions = [ + TileDescription([64, 256, 8], 2, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 8], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 256, 8], 2, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 32, 8], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 8], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 8], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 8], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 32, 8], 2, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([32, 128, 8], 2, [1, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 8], 2, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([32, 64, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc), + TileDescription([64, 32, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc), + TileDescription([32, 32, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc), + TileDescription([8, 32, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc), + TileDescription([16, 32, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc), + TileDescription([16, 64, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc), + TileDescription([16, 128, 8], 2, [1, 2, 1], math_inst, min_cc, max_cc), + ] + for warpsPerThreadblock in warpsPerThreadblocks: + for warpShape in warpShapes: + warpThreadsM = 0 + if warpShape[0] > warpShape[1]: + warpThreadsM = 8 + else: + warpThreadsM = 4 + warpThreadsN = warpNumThreads / warpThreadsM + + # skip shapes with conflicting rectangularity + # they are unlikely to be fastest + blockG = warpsPerThreadblock[0] > warpsPerThreadblock[1] + blockL = warpsPerThreadblock[0] < warpsPerThreadblock[1] + warpG = warpShape[0] > warpShape[1] + warpL = warpShape[0] < 
warpShape[1] + + blockG2 = warpsPerThreadblock[0] > warpsPerThreadblock[1] * 2 + blockL2 = warpsPerThreadblock[0] * 2 < warpsPerThreadblock[1] + warpG2 = warpShape[0] > warpShape[1] * 2 + warpL2 = warpShape[0] * 2 < warpShape[1] + + if blockG2 and warpL: + continue + if blockL2 and warpG: + continue + if warpG2 and blockL: + continue + if warpL2 and blockG: + continue + + # check threadblock ratios and max + threadblockTile = [ + warpShape[0] * warpsPerThreadblock[0], + warpShape[1] * warpsPerThreadblock[1], + ] + if threadblockTile[0] * threadblockTile[1] > threadblockMaxElements: + continue + if threadblockTile[0] > threadblockEdgeMax: + continue + if threadblockTile[1] > threadblockEdgeMax: + continue + totalThreads = ( + warpNumThreads * warpsPerThreadblock[0] * warpsPerThreadblock[1] + ) + + # calculate unroll + # ensure that every iteration at least a full load of A,B are done + unrollMin = 8 + unrollMin0 = totalThreads // threadblockTile[0] + unrollMin1 = totalThreads // threadblockTile[1] + unroll = max(unrollMin, unrollMin0, unrollMin1) + + threadTileM = warpShape[0] // warpThreadsM + threadTileN = warpShape[1] // warpThreadsN + if threadTileM < 2 or threadTileN < 2: + continue + if threadTileM * threadTileN * precisionBits > 8 * 8 * 32: + continue + + # epilogue currently only supports N < WarpNumThreads + if threadblockTile[1] < warpNumThreads: + continue + + # limit smem + smemBitsA = threadblockTile[0] * unroll * 2 * precisionBits + smemBitsB = threadblockTile[1] * unroll * 2 * precisionBits + smemKBytes = (smemBitsA + smemBitsB) / 8 / 1024 + if smemKBytes > 48: + continue + + tile = TileDescription( + [threadblockTile[0], threadblockTile[1], unroll], + 2, + [ + threadblockTile[0] // warpShape[0], + threadblockTile[1] // warpShape[1], + 1, + ], + math_inst, + min_cc, + max_cc, + ) + + def filter(t: TileDescription) -> bool: + nonlocal tile + return ( + t.threadblock_shape[0] == tile.threadblock_shape[0] + and t.threadblock_shape[1] == 
tile.threadblock_shape[1] + and t.threadblock_shape[2] == tile.threadblock_shape[2] + and t.warp_count[0] == tile.warp_count[0] + and t.warp_count[1] == tile.warp_count[1] + and t.warp_count[2] == tile.warp_count[2] + and t.stages == tile.stages + ) + + if not any(t for t in tile_descriptions if filter(t)): + continue + + operations += GeneratesGemm( + tile, data_type, layout[0], layout[1], layout[2], min_cc + ) + return operations + + +def GenerateDwconv2dFprop_Simt(args): + ################################################################################ + # warps per threadblock + ################################################################################ + warpsPerThreadblocks = [] + for warpsPerThreadblock0 in warpsPerThreadblockEdge: + for warpsPerThreadblock1 in warpsPerThreadblockEdge: + if ( + warpsPerThreadblock0 / warpsPerThreadblock1 <= warpsPerThreadblockRatio + and warpsPerThreadblock1 / warpsPerThreadblock0 + <= warpsPerThreadblockRatio + and warpsPerThreadblock0 * warpsPerThreadblock1 + <= warpsPerThreadblockMax + ): + warpsPerThreadblocks.append( + [warpsPerThreadblock0, warpsPerThreadblock1] + ) + + ################################################################################ + # warp shapes + ################################################################################ + warpNumThreads = 32 + warpShapes = [] + for warp0 in warpShapeEdges: + for warp1 in warpShapeEdges: + if ( + warp0 / warp1 <= warpShapeRatio + and warp1 / warp0 <= warpShapeRatio + and warp0 * warp1 <= warpShapeMax + and warp0 * warp1 > warpShapeMin + ): + warpShapes.append([warp0, warp1]) + + # sgemm + precisionType, precisionBits, threadblockMaxElements, threadblockTilesL0 = precisions[ + "s" + ] + + layouts = [(LayoutType.TensorNCHW, LayoutType.TensorNCHW)] + + math_instructions = [ + MathInstruction( + [1, 1, 1], + DataType.f32, + DataType.f32, + DataType.f32, + OpcodeClass.Simt, + MathOperation.multiply_add, + ) + ] + + min_cc = 50 + max_cc = 1024 + + dst_layouts = 
[LayoutType.TensorNCHW] + + dst_types = [DataType.f32] + + alignment_constraints = [128, 32] + + operations = [] + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([128, 128, 8], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 8], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 8], 2, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 32, 8], 2, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([32, 128, 8], 2, [1, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 8], 2, [1, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 64, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc), + TileDescription([64, 32, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc), + TileDescription([32, 32, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc), + ] + for warpsPerThreadblock in warpsPerThreadblocks: + for warpShape in warpShapes: + warpThreadsM = 0 + if warpShape[0] > warpShape[1]: + warpThreadsM = 8 + else: + warpThreadsM = 4 + warpThreadsN = warpNumThreads / warpThreadsM + + # skip shapes with conflicting rectangularity + # they are unlikely to be fastest + blockG = warpsPerThreadblock[0] > warpsPerThreadblock[1] + blockL = warpsPerThreadblock[0] < warpsPerThreadblock[1] + warpG = warpShape[0] > warpShape[1] + warpL = warpShape[0] < warpShape[1] + + blockG2 = warpsPerThreadblock[0] > warpsPerThreadblock[1] * 2 + blockL2 = warpsPerThreadblock[0] * 2 < warpsPerThreadblock[1] + warpG2 = warpShape[0] > warpShape[1] * 2 + warpL2 = warpShape[0] * 2 < warpShape[1] + + if blockG2 and warpL: + continue + if blockL2 and warpG: + continue + if warpG2 and blockL: + continue + if warpL2 and blockG: + continue + + # check threadblock ratios and max + threadblockTile = [ + warpShape[0] * warpsPerThreadblock[0], + warpShape[1] * warpsPerThreadblock[1], + ] + if threadblockTile[0] * threadblockTile[1] > threadblockMaxElements: + continue + if threadblockTile[0] > threadblockEdgeMax: + continue + if 
threadblockTile[1] > threadblockEdgeMax: + continue + totalThreads = ( + warpNumThreads * warpsPerThreadblock[0] * warpsPerThreadblock[1] + ) + + # calculate unroll + # ensure that every iteration at least a full load of A,B are done + unrollMin = 8 + unrollMin0 = totalThreads // threadblockTile[0] + unrollMin1 = totalThreads // threadblockTile[1] + unroll = max(unrollMin, unrollMin0, unrollMin1) + + threadTileM = warpShape[0] // warpThreadsM + threadTileN = warpShape[1] // warpThreadsN + if threadTileM < 2 or threadTileN < 2: + continue + if threadTileM * threadTileN * precisionBits > 8 * 8 * 32: + continue + + # epilogue currently only supports N < WarpNumThreads + if threadblockTile[1] < warpNumThreads: + continue + + # limit smem + smemBitsA = threadblockTile[0] * unroll * 2 * precisionBits + smemBitsB = threadblockTile[1] * unroll * 2 * precisionBits + smemKBytes = (smemBitsA + smemBitsB) / 8 / 1024 + if smemKBytes > 48: + continue + + tile = TileDescription( + [threadblockTile[0], threadblockTile[1], unroll], + 2, + [ + threadblockTile[0] // warpShape[0], + threadblockTile[1] // warpShape[1], + 1, + ], + math_inst, + min_cc, + max_cc, + ) + + def filter(t: TileDescription) -> bool: + nonlocal tile + return ( + t.threadblock_shape[0] == tile.threadblock_shape[0] + and t.threadblock_shape[1] == tile.threadblock_shape[1] + and t.threadblock_shape[2] == tile.threadblock_shape[2] + and t.warp_count[0] == tile.warp_count[0] + and t.warp_count[1] == tile.warp_count[1] + and t.warp_count[2] == tile.warp_count[2] + and t.stages == tile.stages + ) + + if not any(t for t in tile_descriptions if filter(t)): + continue + + for layout in layouts: + for dst_type, dst_layout in zip(dst_types, dst_layouts): + for alignment_src in alignment_constraints: + operations += GenerateConv2d( + ConvType.DepthwiseConvolution, + ConvKind.Fprop, + [tile], + layout[0], + layout[1], + dst_layout, + dst_type, + min_cc, + alignment_src, + 32, + 32, + SpecialOptimizeDesc.NoneSpecialOpt, + 
ImplicitGemmMode.GemmTN, + ) + return operations + + +# +def GenerateDwconv2dFprop_TensorOp_884(args): + layouts = [(LayoutType.TensorNCHW, LayoutType.TensorNCHW)] + + math_instructions = [ + MathInstruction( + [8, 8, 4], + DataType.f16, + DataType.f16, + DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add, + ), + MathInstruction( + [8, 8, 4], + DataType.f16, + DataType.f16, + DataType.f16, + OpcodeClass.TensorOp, + MathOperation.multiply_add, + ), + ] + + min_cc = 70 + max_cc = 75 + + dst_layouts = [LayoutType.TensorNCHW] + + dst_types = [DataType.f16] + + alignment_constraints = [128, 32, 16] + cuda_major = 10 + cuda_minor = 2 + + operations = [] + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([128, 256, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 32], 2, [4, 4, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + ] + for layout in layouts: + for dst_type, dst_layout in zip(dst_types, dst_layouts): + for alignment_src in alignment_constraints: + operations += GenerateConv2d( + ConvType.DepthwiseConvolution, + ConvKind.Fprop, + tile_descriptions, + layout[0], + layout[1], + dst_layout, + dst_type, + min_cc, + alignment_src, + 16, + 16, + SpecialOptimizeDesc.NoneSpecialOpt, + ImplicitGemmMode.GemmTN, + False, + cuda_major, + cuda_minor, + ) + + return operations + # def GenerateGemv_Simt(args): - threadBlockShape_N = [128, 64, 32] - ldgBits_A = [128, 64, 32] - ldgBits_B = [128, 64, 32] - - layouts = [ - (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.RowMajor), - ] - - math_instructions = [ - MathInstruction( \ - [1, 1, 1], \ - DataType.f32, DataType.f32, DataType.f32, \ - OpcodeClass.Simt, \ - MathOperation.multiply_add), - ] - - min_cc = 50 - - operations = [] - for math_inst in 
math_instructions: - for layout in layouts: - data_type = [ - math_inst.element_a, - math_inst.element_b, - math_inst.element_accumulator, - math_inst.element_accumulator, - ] - for threadblock_shape_n in threadBlockShape_N: - for align_a in ldgBits_A: - for align_b in ldgBits_B: - ldg_elements_a = align_a // DataTypeSize[math_inst.element_a] - ldg_elements_b = align_b // DataTypeSize[math_inst.element_b] - threadblock_shape_k = (256 * ldg_elements_a) // (threadblock_shape_n // ldg_elements_b) - threadblock_shape = [1, threadblock_shape_n, threadblock_shape_k] - thread_shape = [1, ldg_elements_b, ldg_elements_a] - - operations.append(GeneratesGemv(math_inst, \ - threadblock_shape, \ - thread_shape, \ - data_type, \ - layout[0], \ - layout[1], \ - layout[2], \ - min_cc, \ - align_a, \ - align_b)) - return operations + threadBlockShape_N = [128, 64, 32] + ldgBits_A = [128, 64, 32] + ldgBits_B = [128, 64, 32] + + layouts = [(LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.RowMajor)] + + math_instructions = [ + MathInstruction( + [1, 1, 1], + DataType.f32, + DataType.f32, + DataType.f32, + OpcodeClass.Simt, + MathOperation.multiply_add, + ) + ] + + min_cc = 50 + + operations = [] + for math_inst in math_instructions: + for layout in layouts: + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + for threadblock_shape_n in threadBlockShape_N: + for align_a in ldgBits_A: + for align_b in ldgBits_B: + ldg_elements_a = align_a // DataTypeSize[math_inst.element_a] + ldg_elements_b = align_b // DataTypeSize[math_inst.element_b] + threadblock_shape_k = (256 * ldg_elements_a) // ( + threadblock_shape_n // ldg_elements_b + ) + threadblock_shape = [ + 1, + threadblock_shape_n, + threadblock_shape_k, + ] + thread_shape = [1, ldg_elements_b, ldg_elements_a] + + operations.append( + GeneratesGemv( + math_inst, + threadblock_shape, + thread_shape, + data_type, + layout[0], + layout[1], + layout[2], + 
min_cc, + align_a, + align_b, + ) + ) + return operations + # def GeneratesGemm_TensorOp_1688(args): - layouts = [ - (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.RowMajor), # nn - (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.RowMajor), # nt - (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor), # tn - (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.RowMajor), # tt - ] - - math_instructions = [ - MathInstruction( \ - [16, 8, 8], \ - DataType.f16, DataType.f16, DataType.f32, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add), - MathInstruction( \ - [16, 8, 8], \ - DataType.f16, DataType.f16, DataType.f16, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add), - ] - - min_cc = 75 - max_cc = 1024 - - alignment_constraints = [8, 4, 2, - #1 - ] - cuda_major = 10 - cuda_minor = 2 - - operations = [] - for math_inst in math_instructions: - for layout in layouts: - for align in alignment_constraints: - tile_descriptions = [ - TileDescription([256, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 256, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), -## comment some configuration to reduce compilation time and binary size -# TileDescription([ 64, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), -# TileDescription([128, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), -# TileDescription([ 64, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), - ] - - data_type = [ - math_inst.element_a, - math_inst.element_b, - math_inst.element_a, - math_inst.element_accumulator, - ] - - for tile in tile_descriptions: - operations += GeneratesGemm(tile, \ - data_type, \ - layout[0], \ - layout[1], \ - layout[2], \ - min_cc, \ - align * 16, \ - align * 16, \ - align * 16, \ - cuda_major, \ - cuda_minor) - return operations + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.RowMajor), # nn + 
(LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.RowMajor), # nt + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor), # tn + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.RowMajor), # tt + ] + + math_instructions = [ + MathInstruction( + [16, 8, 8], + DataType.f16, + DataType.f16, + DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add, + ), + MathInstruction( + [16, 8, 8], + DataType.f16, + DataType.f16, + DataType.f16, + OpcodeClass.TensorOp, + MathOperation.multiply_add, + ), + ] + + min_cc = 75 + max_cc = 1024 + + alignment_constraints = [ + 8, + 4, + 2, + # 1 + ] + cuda_major = 10 + cuda_minor = 2 + + operations = [] + for math_inst in math_instructions: + for layout in layouts: + for align in alignment_constraints: + tile_descriptions = [ + TileDescription( + [256, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [128, 256, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [128, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc + ), + ## comment some configuration to reduce compilation time and binary size + # TileDescription([ 64, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + # TileDescription([128, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + # TileDescription([ 64, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_a, + math_inst.element_accumulator, + ] + + for tile in tile_descriptions: + operations += GeneratesGemm( + tile, + data_type, + layout[0], + layout[1], + layout[2], + min_cc, + align * 16, + align * 16, + align * 16, + cuda_major, + cuda_minor, + ) + return operations + # def GeneratesGemm_TensorOp_884(args): - layouts = [ - (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.RowMajor), # nn - (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.RowMajor), # nt - (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor), # tn - 
(LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.RowMajor), # tt - ] - - math_instructions = [ - MathInstruction( \ - [8, 8, 4], \ - DataType.f16, DataType.f16, DataType.f32, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add), - MathInstruction( \ - [8, 8, 4], \ - DataType.f16, DataType.f16, DataType.f16, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add), - ] - - min_cc = 70 - max_cc = 75 - - alignment_constraints = [8, 4, 2, - # 1 - ] - cuda_major = 10 - cuda_minor = 2 - - operations = [] - for math_inst in math_instructions: - for layout in layouts: - for align in alignment_constraints: - tile_descriptions = [ - TileDescription([256, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 256, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), -## comment some configuration to reduce compilation time and binary size -# TileDescription([ 64, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), -# TileDescription([128, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), -# TileDescription([ 64, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), - ] - - data_type = [ - math_inst.element_a, - math_inst.element_b, - math_inst.element_a, - math_inst.element_accumulator, - ] - - for tile in tile_descriptions: - operations += GeneratesGemm(tile, \ - data_type, \ - layout[0], \ - layout[1], \ - layout[2], \ - min_cc, \ - align * 16, \ - align * 16, \ - align * 16, \ - cuda_major, \ - cuda_minor) - - return operations + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.RowMajor), # nn + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.RowMajor), # nt + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor), # tn + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.RowMajor), # tt + ] + + math_instructions = [ + MathInstruction( + [8, 8, 4], + DataType.f16, + DataType.f16, + DataType.f32, + OpcodeClass.TensorOp, + 
MathOperation.multiply_add, + ), + MathInstruction( + [8, 8, 4], + DataType.f16, + DataType.f16, + DataType.f16, + OpcodeClass.TensorOp, + MathOperation.multiply_add, + ), + ] + + min_cc = 70 + max_cc = 75 + + alignment_constraints = [ + 8, + 4, + 2, + # 1 + ] + cuda_major = 10 + cuda_minor = 2 + + operations = [] + for math_inst in math_instructions: + for layout in layouts: + for align in alignment_constraints: + tile_descriptions = [ + TileDescription( + [256, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [128, 256, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [128, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc + ), + ## comment some configuration to reduce compilation time and binary size + # TileDescription([ 64, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + # TileDescription([128, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + # TileDescription([ 64, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_a, + math_inst.element_accumulator, + ] + + for tile in tile_descriptions: + operations += GeneratesGemm( + tile, + data_type, + layout[0], + layout[1], + layout[2], + min_cc, + align * 16, + align * 16, + align * 16, + cuda_major, + cuda_minor, + ) + + return operations + # def GenerateConv2dOperations(args): - if args.type == "simt": - return GenerateConv2d_Simt(args) - elif args.type == "tensorop8816": - return GenerateConv2d_TensorOp_8816(args) - else: - assert args.type == "tensorop8832", "operation conv2d only support" \ - "simt, tensorop8816 and tensorop8832. (got:{})".format(args.type) - return GenerateConv2d_TensorOp_8832(args) + if args.type == "simt": + return GenerateConv2d_Simt(args) + elif args.type == "tensorop8816": + return GenerateConv2d_TensorOp_8816(args) + else: + assert args.type == "tensorop8832", ( + "operation conv2d only support" + "simt, tensorop8816 and tensorop8832. 
(got:{})".format(args.type) + ) + return GenerateConv2d_TensorOp_8832(args) + def GenerateDeconvOperations(args): - if args.type == "simt": - return GenerateDeconv_Simt(args) - else: - assert args.type == "tensorop8816", "operation deconv only support" \ - "simt and tensorop8816. (got:{})".format(args.type) - return GenerateDeconv_TensorOp_8816(args) + if args.type == "simt": + return GenerateDeconv_Simt(args) + else: + assert args.type == "tensorop8816", ( + "operation deconv only support" + "simt and tensorop8816. (got:{})".format(args.type) + ) + return GenerateDeconv_TensorOp_8816(args) + + +def GenerateDwconv2dFpropOperations(args): + if args.type == "simt": + return GenerateDwconv2dFprop_Simt(args) + else: + assert args.type == "tensorop884", ( + "operation dwconv2d fprop only support" + "simt, tensorop884. (got:{})".format(args.type) + ) + return GenerateDwconv2dFprop_TensorOp_884(args) + def GenerateGemmOperations(args): - if args.type == "tensorop884": - return GeneratesGemm_TensorOp_884(args) - elif args.type == "tensorop1688": - return GeneratesGemm_TensorOp_1688(args) - else: - assert args.type == "simt", "operation gemm only support" \ - "simt. (got:{})".format(args.type) - return GenerateGemm_Simt(args) + if args.type == "tensorop884": + return GeneratesGemm_TensorOp_884(args) + elif args.type == "tensorop1688": + return GeneratesGemm_TensorOp_1688(args) + else: + assert ( + args.type == "simt" + ), "operation gemm only support" "simt. (got:{})".format(args.type) + return GenerateGemm_Simt(args) + def GenerateGemvOperations(args): - assert args.type == "simt", "operation gemv only support" \ - "simt. (got:{})".format(args.type) - return GenerateGemv_Simt(args) + assert args.type == "simt", "operation gemv only support" "simt. 
(got:{})".format( + args.type + ) + return GenerateGemv_Simt(args) + ################################################################################################### ################################################################################################### if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Generates device kernel registration code for CUTLASS Kernels") - parser.add_argument("--operations", type=str, choices=['gemm', 'gemv', 'conv2d', 'deconv'], - required=True, help="Specifies the operation to generate (gemm, gemv, conv2d, deconv)") - parser.add_argument("output", type=str, help="output directory for CUTLASS kernel files") - parser.add_argument("--type", type=str, choices=['simt', 'tensorop8816', 'tensorop8832', 'tensorop884', 'tensorop1688'], - default='simt', help="kernel type of CUTLASS kernel generator") - - gemv_wrapper_path = "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper_batched_gemv_strided.cuinl" - short_path = (platform.system() == "Windows" or platform.system().find('NT') >= 0) and ('true'!= os.getenv("CUTLASS_WITH_LONG_PATH", default='False').lower()) - args = parser.parse_args() - - if args.operations == "gemm": - operations = GenerateGemmOperations(args) - elif args.operations == "gemv": - operations = GenerateGemvOperations(args) - elif args.operations == "conv2d": - operations = GenerateConv2dOperations(args) - elif args.operations == "deconv": - operations = GenerateDeconvOperations(args) - - if args.operations == "conv2d" or args.operations == "deconv": - for operation in operations: - with EmitConvSingleKernelWrapper(args.output, operation, short_path) as emitter: - emitter.emit() - elif args.operations == "gemm": - for operation in operations: - with EmitGemmSingleKernelWrapper(args.output, operation, short_path) as emitter: - emitter.emit() - elif args.operations == "gemv": - for operation in operations: - with EmitGemvSingleKernelWrapper(args.output, operation, gemv_wrapper_path, 
short_path) as emitter: - emitter.emit() - - if args.operations != "gemv": - GenerateManifest(args, operations, args.output) + parser = argparse.ArgumentParser( + description="Generates device kernel registration code for CUTLASS Kernels" + ) + parser.add_argument( + "--operations", + type=str, + choices=[ + "gemm", + "gemv", + "conv2d", + "deconv", + "dwconv2d_fprop", + "dwconv2d_dgrad", + "dwconv2d_wgrad", + ], + required=True, + help="Specifies the operation to generate (gemm, gemv, conv2d, deconv, dwconv2d_fprop, dwconv2d_dgrad, dwconv2d_wgrad)", + ) + parser.add_argument( + "output", type=str, help="output directory for CUTLASS kernel files" + ) + parser.add_argument( + "--type", + type=str, + choices=["simt", "tensorop8816", "tensorop8832", "tensorop884", "tensorop1688"], + default="simt", + help="kernel type of CUTLASS kernel generator", + ) + + gemv_wrapper_path = ( + "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper_batched_gemv_strided.cuinl" + ) + short_path = ( + platform.system() == "Windows" or platform.system().find("NT") >= 0 + ) and ("true" != os.getenv("CUTLASS_WITH_LONG_PATH", default="False").lower()) + args = parser.parse_args() + + if args.operations == "gemm": + operations = GenerateGemmOperations(args) + elif args.operations == "gemv": + operations = GenerateGemvOperations(args) + elif args.operations == "conv2d": + operations = GenerateConv2dOperations(args) + elif args.operations == "deconv": + operations = GenerateDeconvOperations(args) + elif args.operations == "dwconv2d_fprop": + operations = GenerateDwconv2dFpropOperations(args) + elif args.operations == "dwconv2d_dgrad": + pass + elif args.operations == "dwconv2d_wgrad": + pass + + if ( + args.operations == "conv2d" + or args.operations == "deconv" + or args.operations == "dwconv2d_fprop" + or args.operations == "dwconv2d_dgrad" + or args.operations == "dwconv2d_wgrad" + ): + for operation in operations: + with EmitConvSingleKernelWrapper( + args.output, operation, short_path + ) as 
emitter: + emitter.emit() + elif args.operations == "gemm": + for operation in operations: + with EmitGemmSingleKernelWrapper( + args.output, operation, short_path + ) as emitter: + emitter.emit() + elif args.operations == "gemv": + for operation in operations: + with EmitGemvSingleKernelWrapper( + args.output, operation, gemv_wrapper_path, short_path + ) as emitter: + emitter.emit() + + if args.operations != "gemv": + GenerateManifest(args, operations, args.output) # ################################################################################################### - diff --git a/dnn/scripts/cutlass_generator/library.py b/dnn/scripts/cutlass_generator/library.py index 466ddc25..60f2f3af 100644 --- a/dnn/scripts/cutlass_generator/library.py +++ b/dnn/scripts/cutlass_generator/library.py @@ -12,635 +12,682 @@ import enum # The following block implements enum.auto() for Python 3.5 variants that don't include it such # as the default 3.5.2 on Ubuntu 16.04. -# +# # https://codereview.stackexchange.com/questions/177309/reimplementing-pythons-enum-auto-for-compatibility try: - from enum import auto as enum_auto -except ImportError: - __cutlass_library_auto_enum = 0 - def enum_auto() -> int: - global __cutlass_library_auto_enum - i = __cutlass_library_auto_enum - __cutlass_library_auto_enum += 1 - return i + from enum import auto as enum_auto +except ImportError: + __cutlass_library_auto_enum = 0 + + def enum_auto() -> int: + global __cutlass_library_auto_enum + i = __cutlass_library_auto_enum + __cutlass_library_auto_enum += 1 + return i + ################################################################################################### # class GeneratorTarget(enum.Enum): - Library = enum_auto() + Library = enum_auto() + + # -GeneratorTargetNames = { - GeneratorTarget.Library: 'library' -} +GeneratorTargetNames = {GeneratorTarget.Library: "library"} # ################################################################################################### # class 
DataType(enum.Enum): - b1 = enum_auto() - u4 = enum_auto() - u8 = enum_auto() - u16 = enum_auto() - u32 = enum_auto() - u64 = enum_auto() - s4 = enum_auto() - s8 = enum_auto() - s16 = enum_auto() - s32 = enum_auto() - s64 = enum_auto() - f16 = enum_auto() - bf16 = enum_auto() - f32 = enum_auto() - tf32 = enum_auto() - f64 = enum_auto() - cf16 = enum_auto() - cbf16 = enum_auto() - cf32 = enum_auto() - ctf32 = enum_auto() - cf64 = enum_auto() - cs4 = enum_auto() - cs8 = enum_auto() - cs16 = enum_auto() - cs32 = enum_auto() - cs64 = enum_auto() - cu4 = enum_auto() - cu8 = enum_auto() - cu16 = enum_auto() - cu32 = enum_auto() - cu64 = enum_auto() - invalid = enum_auto() + b1 = enum_auto() + u4 = enum_auto() + u8 = enum_auto() + u16 = enum_auto() + u32 = enum_auto() + u64 = enum_auto() + s4 = enum_auto() + s8 = enum_auto() + s16 = enum_auto() + s32 = enum_auto() + s64 = enum_auto() + f16 = enum_auto() + bf16 = enum_auto() + f32 = enum_auto() + tf32 = enum_auto() + f64 = enum_auto() + cf16 = enum_auto() + cbf16 = enum_auto() + cf32 = enum_auto() + ctf32 = enum_auto() + cf64 = enum_auto() + cs4 = enum_auto() + cs8 = enum_auto() + cs16 = enum_auto() + cs32 = enum_auto() + cs64 = enum_auto() + cu4 = enum_auto() + cu8 = enum_auto() + cu16 = enum_auto() + cu32 = enum_auto() + cu64 = enum_auto() + invalid = enum_auto() + # ShortDataTypeNames = { - DataType.s32: 'i', - DataType.f16: 'h', - DataType.f32: 's', - DataType.f64: 'd', - DataType.cf32: 'c', - DataType.cf64: 'z', + DataType.s32: "i", + DataType.f16: "h", + DataType.f32: "s", + DataType.f64: "d", + DataType.cf32: "c", + DataType.cf64: "z", } # DataTypeNames = { - DataType.b1: "b1", - DataType.u4: "u4", - DataType.u8: "u8", - DataType.u16: "u16", - DataType.u32: "u32", - DataType.u64: "u64", - DataType.s4: "s4", - DataType.s8: "s8", - DataType.s16: "s16", - DataType.s32: "s32", - DataType.s64: "s64", - DataType.f16: "f16", - DataType.bf16: "bf16", - DataType.f32: "f32", - DataType.tf32: "tf32", - DataType.f64: "f64", - 
DataType.cf16: "cf16", - DataType.cbf16: "cbf16", - DataType.cf32: "cf32", - DataType.ctf32: "ctf32", - DataType.cf64: "cf64", - DataType.cu4: "cu4", - DataType.cu8: "cu8", - DataType.cu16: "cu16", - DataType.cu32: "cu32", - DataType.cu64: "cu64", - DataType.cs4: "cs4", - DataType.cs8: "cs8", - DataType.cs16: "cs16", - DataType.cs32: "cs32", - DataType.cs64: "cs64", + DataType.b1: "b1", + DataType.u4: "u4", + DataType.u8: "u8", + DataType.u16: "u16", + DataType.u32: "u32", + DataType.u64: "u64", + DataType.s4: "s4", + DataType.s8: "s8", + DataType.s16: "s16", + DataType.s32: "s32", + DataType.s64: "s64", + DataType.f16: "f16", + DataType.bf16: "bf16", + DataType.f32: "f32", + DataType.tf32: "tf32", + DataType.f64: "f64", + DataType.cf16: "cf16", + DataType.cbf16: "cbf16", + DataType.cf32: "cf32", + DataType.ctf32: "ctf32", + DataType.cf64: "cf64", + DataType.cu4: "cu4", + DataType.cu8: "cu8", + DataType.cu16: "cu16", + DataType.cu32: "cu32", + DataType.cu64: "cu64", + DataType.cs4: "cs4", + DataType.cs8: "cs8", + DataType.cs16: "cs16", + DataType.cs32: "cs32", + DataType.cs64: "cs64", } DataTypeTag = { - DataType.b1: "cutlass::uint1b_t", - DataType.u4: "cutlass::uint4b_t", - DataType.u8: "uint8_t", - DataType.u16: "uint16_t", - DataType.u32: "uint32_t", - DataType.u64: "uint64_t", - DataType.s4: "cutlass::int4b_t", - DataType.s8: "int8_t", - DataType.s16: "int16_t", - DataType.s32: "int32_t", - DataType.s64: "int64_t", - DataType.f16: "cutlass::half_t", - DataType.bf16: "cutlass::bfloat16_t", - DataType.f32: "float", - DataType.tf32: "cutlass::tfloat32_t", - DataType.f64: "double", - DataType.cf16: "cutlass::complex", - DataType.cbf16: "cutlass::complex", - DataType.cf32: "cutlass::complex", - DataType.ctf32: "cutlass::complex", - DataType.cf64: "cutlass::complex", - DataType.cu4: "cutlass::complex", - DataType.cu8: "cutlass::complex", - DataType.cu16: "cutlass::complex", - DataType.cu32: "cutlass::complex", - DataType.cu64: "cutlass::complex", - DataType.cs4: 
"cutlass::complex", - DataType.cs8: "cutlass::complex", - DataType.cs16: "cutlass::complex", - DataType.cs32: "cutlass::complex", - DataType.cs64: "cutlass::complex", + DataType.b1: "cutlass::uint1b_t", + DataType.u4: "cutlass::uint4b_t", + DataType.u8: "uint8_t", + DataType.u16: "uint16_t", + DataType.u32: "uint32_t", + DataType.u64: "uint64_t", + DataType.s4: "cutlass::int4b_t", + DataType.s8: "int8_t", + DataType.s16: "int16_t", + DataType.s32: "int32_t", + DataType.s64: "int64_t", + DataType.f16: "cutlass::half_t", + DataType.bf16: "cutlass::bfloat16_t", + DataType.f32: "float", + DataType.tf32: "cutlass::tfloat32_t", + DataType.f64: "double", + DataType.cf16: "cutlass::complex", + DataType.cbf16: "cutlass::complex", + DataType.cf32: "cutlass::complex", + DataType.ctf32: "cutlass::complex", + DataType.cf64: "cutlass::complex", + DataType.cu4: "cutlass::complex", + DataType.cu8: "cutlass::complex", + DataType.cu16: "cutlass::complex", + DataType.cu32: "cutlass::complex", + DataType.cu64: "cutlass::complex", + DataType.cs4: "cutlass::complex", + DataType.cs8: "cutlass::complex", + DataType.cs16: "cutlass::complex", + DataType.cs32: "cutlass::complex", + DataType.cs64: "cutlass::complex", } DataTypeSize = { - DataType.b1: 1, - DataType.u4: 4, - DataType.u8: 4, - DataType.u16: 16, - DataType.u32: 32, - DataType.u64: 64, - DataType.s4: 4, - DataType.s8: 8, - DataType.s16: 16, - DataType.s32: 32, - DataType.s64: 64, - DataType.f16: 16, - DataType.bf16: 16, - DataType.f32: 32, - DataType.tf32: 32, - DataType.f64: 64, - DataType.cf16: 32, - DataType.cbf16: 32, - DataType.cf32: 64, - DataType.ctf32: 32, - DataType.cf64: 128, - DataType.cu4: 8, - DataType.cu8: 16, - DataType.cu16: 32, - DataType.cu32: 64, - DataType.cu64: 128, - DataType.cs4: 8, - DataType.cs8: 16, - DataType.cs16: 32, - DataType.cs32: 64, - DataType.cs64: 128, + DataType.b1: 1, + DataType.u4: 4, + DataType.u8: 4, + DataType.u16: 16, + DataType.u32: 32, + DataType.u64: 64, + DataType.s4: 4, + 
DataType.s8: 8, + DataType.s16: 16, + DataType.s32: 32, + DataType.s64: 64, + DataType.f16: 16, + DataType.bf16: 16, + DataType.f32: 32, + DataType.tf32: 32, + DataType.f64: 64, + DataType.cf16: 32, + DataType.cbf16: 32, + DataType.cf32: 64, + DataType.ctf32: 32, + DataType.cf64: 128, + DataType.cu4: 8, + DataType.cu8: 16, + DataType.cu16: 32, + DataType.cu32: 64, + DataType.cu64: 128, + DataType.cs4: 8, + DataType.cs8: 16, + DataType.cs16: 32, + DataType.cs32: 64, + DataType.cs64: 128, } ################################################################################################### # class ComplexTransform(enum.Enum): - none = enum_auto() - conj = enum_auto() + none = enum_auto() + conj = enum_auto() + # ComplexTransformTag = { - ComplexTransform.none: 'cutlass::ComplexTransform::kNone', - ComplexTransform.conj: 'cutlass::ComplexTransform::kConjugate', + ComplexTransform.none: "cutlass::ComplexTransform::kNone", + ComplexTransform.conj: "cutlass::ComplexTransform::kConjugate", } # RealComplexBijection = [ - (DataType.f16, DataType.cf16), - (DataType.f32, DataType.cf32), - (DataType.f64, DataType.cf64), + (DataType.f16, DataType.cf16), + (DataType.f32, DataType.cf32), + (DataType.f64, DataType.cf64), ] # def is_complex(data_type): - for r, c in RealComplexBijection: - if data_type == c: - return True - return False + for r, c in RealComplexBijection: + if data_type == c: + return True + return False + # def get_complex_from_real(real_type): - for r, c in RealComplexBijection: - if real_type == r: - return c - return DataType.invalid + for r, c in RealComplexBijection: + if real_type == r: + return c + return DataType.invalid + # def get_real_from_complex(complex_type): - for r, c in RealComplexBijection: - if complex_type == c: - return r - return DataType.invalid + for r, c in RealComplexBijection: + if complex_type == c: + return r + return DataType.invalid + # class ComplexMultiplyOp(enum.Enum): - multiply_add = enum_auto() - gaussian = enum_auto() + 
multiply_add = enum_auto() + gaussian = enum_auto() + ################################################################################################### # class MathOperation(enum.Enum): - multiply_add = enum_auto() - multiply_add_saturate = enum_auto() - xor_popc = enum_auto() - multiply_add_fast_bf16 = enum_auto() - multiply_add_fast_f16 = enum_auto() - multiply_add_complex = enum_auto() - multiply_add_complex_gaussian = enum_auto() + multiply_add = enum_auto() + multiply_add_saturate = enum_auto() + xor_popc = enum_auto() + multiply_add_fast_bf16 = enum_auto() + multiply_add_fast_f16 = enum_auto() + multiply_add_complex = enum_auto() + multiply_add_complex_gaussian = enum_auto() + # MathOperationTag = { - MathOperation.multiply_add: 'cutlass::arch::OpMultiplyAdd', - MathOperation.multiply_add_saturate: 'cutlass::arch::OpMultiplyAddSaturate', - MathOperation.xor_popc: 'cutlass::arch::OpXorPopc', - MathOperation.multiply_add_fast_bf16: 'cutlass::arch::OpMultiplyAddFastBF16', - MathOperation.multiply_add_fast_f16: 'cutlass::arch::OpMultiplyAddFastF16', - MathOperation.multiply_add_complex: 'cutlass::arch::OpMultiplyAddComplex', - MathOperation.multiply_add_complex_gaussian: 'cutlass::arch::OpMultiplyAddGaussianComplex', + MathOperation.multiply_add: "cutlass::arch::OpMultiplyAdd", + MathOperation.multiply_add_saturate: "cutlass::arch::OpMultiplyAddSaturate", + MathOperation.xor_popc: "cutlass::arch::OpXorPopc", + MathOperation.multiply_add_fast_bf16: "cutlass::arch::OpMultiplyAddFastBF16", + MathOperation.multiply_add_fast_f16: "cutlass::arch::OpMultiplyAddFastF16", + MathOperation.multiply_add_complex: "cutlass::arch::OpMultiplyAddComplex", + MathOperation.multiply_add_complex_gaussian: "cutlass::arch::OpMultiplyAddGaussianComplex", } ################################################################################################### # class LayoutType(enum.Enum): - ColumnMajor = enum_auto() - RowMajor = enum_auto() - ColumnMajorInterleaved2 = enum_auto() - 
RowMajorInterleaved2 = enum_auto() - ColumnMajorInterleaved32 = enum_auto() - RowMajorInterleaved32 = enum_auto() - ColumnMajorInterleaved64 = enum_auto() - RowMajorInterleaved64 = enum_auto() - TensorNHWC = enum_auto() - TensorNDHWC = enum_auto() - TensorNCHW = enum_auto() - TensorNGHWC = enum_auto() - TensorNC4HW4 = enum_auto() - TensorC4RSK4 = enum_auto() - TensorNC8HW8 = enum_auto() - TensorNC16HW16 = enum_auto() - TensorNC32HW32 = enum_auto() - TensorNC64HW64 = enum_auto() - TensorC32RSK32 = enum_auto() - TensorC64RSK64 = enum_auto() - TensorK4RSC4 = enum_auto() - TensorCK4RS4 = enum_auto() - TensorCK8RS8 = enum_auto() - TensorCK16RS16 = enum_auto() + ColumnMajor = enum_auto() + RowMajor = enum_auto() + ColumnMajorInterleaved2 = enum_auto() + RowMajorInterleaved2 = enum_auto() + ColumnMajorInterleaved32 = enum_auto() + RowMajorInterleaved32 = enum_auto() + ColumnMajorInterleaved64 = enum_auto() + RowMajorInterleaved64 = enum_auto() + TensorNHWC = enum_auto() + TensorNDHWC = enum_auto() + TensorNCHW = enum_auto() + TensorNGHWC = enum_auto() + TensorNC4HW4 = enum_auto() + TensorC4RSK4 = enum_auto() + TensorNC8HW8 = enum_auto() + TensorNC16HW16 = enum_auto() + TensorNC32HW32 = enum_auto() + TensorNC64HW64 = enum_auto() + TensorC32RSK32 = enum_auto() + TensorC64RSK64 = enum_auto() + TensorK4RSC4 = enum_auto() + TensorCK4RS4 = enum_auto() + TensorCK8RS8 = enum_auto() + TensorCK16RS16 = enum_auto() + # LayoutTag = { - LayoutType.ColumnMajor: 'cutlass::layout::ColumnMajor', - LayoutType.RowMajor: 'cutlass::layout::RowMajor', - LayoutType.ColumnMajorInterleaved2: 'cutlass::layout::ColumnMajorInterleaved<2>', - LayoutType.RowMajorInterleaved2: 'cutlass::layout::RowMajorInterleaved<2>', - LayoutType.ColumnMajorInterleaved32: 'cutlass::layout::ColumnMajorInterleaved<32>', - LayoutType.RowMajorInterleaved32: 'cutlass::layout::RowMajorInterleaved<32>', - LayoutType.ColumnMajorInterleaved64: 'cutlass::layout::ColumnMajorInterleaved<64>', - LayoutType.RowMajorInterleaved64: 
'cutlass::layout::RowMajorInterleaved<64>', - LayoutType.TensorNHWC: 'cutlass::layout::TensorNHWC', - LayoutType.TensorNDHWC: 'cutlass::layout::TensorNDHWC', - LayoutType.TensorNCHW: 'cutlass::layout::TensorNCHW', - LayoutType.TensorNGHWC: 'cutlass::layout::TensorNGHWC', - LayoutType.TensorNC4HW4: 'cutlass::layout::TensorNCxHWx<4>', - LayoutType.TensorC4RSK4: 'cutlass::layout::TensorCxRSKx<4>', - LayoutType.TensorNC8HW8: 'cutlass::layout::TensorNCxHWx<8>', - LayoutType.TensorNC16HW16: 'cutlass::layout::TensorNCxHWx<16>', - LayoutType.TensorNC32HW32: 'cutlass::layout::TensorNCxHWx<32>', - LayoutType.TensorC32RSK32: 'cutlass::layout::TensorCxRSKx<32>', - LayoutType.TensorNC64HW64: 'cutlass::layout::TensorNCxHWx<64>', - LayoutType.TensorC64RSK64: 'cutlass::layout::TensorCxRSKx<64>', - LayoutType.TensorK4RSC4: 'cutlass::layout::TensorKxRSCx<4>', - LayoutType.TensorCK4RS4: 'cutlass::layout::TensorCKxRSx<4>', - LayoutType.TensorCK8RS8: 'cutlass::layout::TensorCKxRSx<8>', - LayoutType.TensorCK16RS16: 'cutlass::layout::TensorCKxRSx<16>', + LayoutType.ColumnMajor: "cutlass::layout::ColumnMajor", + LayoutType.RowMajor: "cutlass::layout::RowMajor", + LayoutType.ColumnMajorInterleaved2: "cutlass::layout::ColumnMajorInterleaved<2>", + LayoutType.RowMajorInterleaved2: "cutlass::layout::RowMajorInterleaved<2>", + LayoutType.ColumnMajorInterleaved32: "cutlass::layout::ColumnMajorInterleaved<32>", + LayoutType.RowMajorInterleaved32: "cutlass::layout::RowMajorInterleaved<32>", + LayoutType.ColumnMajorInterleaved64: "cutlass::layout::ColumnMajorInterleaved<64>", + LayoutType.RowMajorInterleaved64: "cutlass::layout::RowMajorInterleaved<64>", + LayoutType.TensorNHWC: "cutlass::layout::TensorNHWC", + LayoutType.TensorNDHWC: "cutlass::layout::TensorNDHWC", + LayoutType.TensorNCHW: "cutlass::layout::TensorNCHW", + LayoutType.TensorNGHWC: "cutlass::layout::TensorNGHWC", + LayoutType.TensorNC4HW4: "cutlass::layout::TensorNCxHWx<4>", + LayoutType.TensorC4RSK4: 
"cutlass::layout::TensorCxRSKx<4>", + LayoutType.TensorNC8HW8: "cutlass::layout::TensorNCxHWx<8>", + LayoutType.TensorNC16HW16: "cutlass::layout::TensorNCxHWx<16>", + LayoutType.TensorNC32HW32: "cutlass::layout::TensorNCxHWx<32>", + LayoutType.TensorC32RSK32: "cutlass::layout::TensorCxRSKx<32>", + LayoutType.TensorNC64HW64: "cutlass::layout::TensorNCxHWx<64>", + LayoutType.TensorC64RSK64: "cutlass::layout::TensorCxRSKx<64>", + LayoutType.TensorK4RSC4: "cutlass::layout::TensorKxRSCx<4>", + LayoutType.TensorCK4RS4: "cutlass::layout::TensorCKxRSx<4>", + LayoutType.TensorCK8RS8: "cutlass::layout::TensorCKxRSx<8>", + LayoutType.TensorCK16RS16: "cutlass::layout::TensorCKxRSx<16>", } # TransposedLayout = { - LayoutType.ColumnMajor: LayoutType.RowMajor, - LayoutType.RowMajor: LayoutType.ColumnMajor, - LayoutType.ColumnMajorInterleaved2: LayoutType.RowMajorInterleaved2, - LayoutType.RowMajorInterleaved2: LayoutType.ColumnMajorInterleaved2, - LayoutType.ColumnMajorInterleaved32: LayoutType.RowMajorInterleaved32, - LayoutType.RowMajorInterleaved32: LayoutType.ColumnMajorInterleaved32, - LayoutType.ColumnMajorInterleaved64: LayoutType.RowMajorInterleaved64, - LayoutType.RowMajorInterleaved64: LayoutType.ColumnMajorInterleaved64, - LayoutType.TensorNHWC: LayoutType.TensorNHWC + LayoutType.ColumnMajor: LayoutType.RowMajor, + LayoutType.RowMajor: LayoutType.ColumnMajor, + LayoutType.ColumnMajorInterleaved2: LayoutType.RowMajorInterleaved2, + LayoutType.RowMajorInterleaved2: LayoutType.ColumnMajorInterleaved2, + LayoutType.ColumnMajorInterleaved32: LayoutType.RowMajorInterleaved32, + LayoutType.RowMajorInterleaved32: LayoutType.ColumnMajorInterleaved32, + LayoutType.ColumnMajorInterleaved64: LayoutType.RowMajorInterleaved64, + LayoutType.RowMajorInterleaved64: LayoutType.ColumnMajorInterleaved64, + LayoutType.TensorNHWC: LayoutType.TensorNHWC, } # ShortLayoutTypeNames = { - LayoutType.ColumnMajor: 'n', - LayoutType.ColumnMajorInterleaved32: 'n2', - 
LayoutType.ColumnMajorInterleaved32: 'n32', - LayoutType.ColumnMajorInterleaved64: 'n64', - LayoutType.RowMajor: 't', - LayoutType.RowMajorInterleaved2: 't2', - LayoutType.RowMajorInterleaved32: 't32', - LayoutType.RowMajorInterleaved64: 't64', - LayoutType.TensorNHWC: 'nhwc', - LayoutType.TensorNDHWC: 'ndhwc', - LayoutType.TensorNCHW: 'nchw', - LayoutType.TensorNGHWC: 'nghwc', - LayoutType.TensorNC4HW4: 'nc4hw4', - LayoutType.TensorC4RSK4: 'c4rsk4', - LayoutType.TensorNC8HW8: 'nc8hw8', - LayoutType.TensorNC16HW16: 'nc16hw16', - LayoutType.TensorNC32HW32: 'nc32hw32', - LayoutType.TensorNC64HW64: 'nc64hw64', - LayoutType.TensorC32RSK32: 'c32rsk32', - LayoutType.TensorC64RSK64: 'c64rsk64', - LayoutType.TensorK4RSC4: 'k4rsc4', - LayoutType.TensorCK4RS4: 'ck4rs4', - LayoutType.TensorCK8RS8: 'ck8rs8', - LayoutType.TensorCK16RS16: 'ck16rs16', + LayoutType.ColumnMajor: "n", + LayoutType.ColumnMajorInterleaved32: "n2", + LayoutType.ColumnMajorInterleaved32: "n32", + LayoutType.ColumnMajorInterleaved64: "n64", + LayoutType.RowMajor: "t", + LayoutType.RowMajorInterleaved2: "t2", + LayoutType.RowMajorInterleaved32: "t32", + LayoutType.RowMajorInterleaved64: "t64", + LayoutType.TensorNHWC: "nhwc", + LayoutType.TensorNDHWC: "ndhwc", + LayoutType.TensorNCHW: "nchw", + LayoutType.TensorNGHWC: "nghwc", + LayoutType.TensorNC4HW4: "nc4hw4", + LayoutType.TensorC4RSK4: "c4rsk4", + LayoutType.TensorNC8HW8: "nc8hw8", + LayoutType.TensorNC16HW16: "nc16hw16", + LayoutType.TensorNC32HW32: "nc32hw32", + LayoutType.TensorNC64HW64: "nc64hw64", + LayoutType.TensorC32RSK32: "c32rsk32", + LayoutType.TensorC64RSK64: "c64rsk64", + LayoutType.TensorK4RSC4: "k4rsc4", + LayoutType.TensorCK4RS4: "ck4rs4", + LayoutType.TensorCK8RS8: "ck8rs8", + LayoutType.TensorCK16RS16: "ck16rs16", } # ShortComplexLayoutNames = { - (LayoutType.ColumnMajor, ComplexTransform.none): 'n', - (LayoutType.ColumnMajor, ComplexTransform.conj): 'c', - (LayoutType.RowMajor, ComplexTransform.none): 't', - (LayoutType.RowMajor, 
ComplexTransform.conj): 'h' + (LayoutType.ColumnMajor, ComplexTransform.none): "n", + (LayoutType.ColumnMajor, ComplexTransform.conj): "c", + (LayoutType.RowMajor, ComplexTransform.none): "t", + (LayoutType.RowMajor, ComplexTransform.conj): "h", } ################################################################################################### # class OpcodeClass(enum.Enum): - Simt = enum_auto() - TensorOp = enum_auto() - WmmaTensorOp = enum_auto() + Simt = enum_auto() + TensorOp = enum_auto() + WmmaTensorOp = enum_auto() + OpcodeClassNames = { - OpcodeClass.Simt: 'simt', - OpcodeClass.TensorOp: 'tensorop', - OpcodeClass.WmmaTensorOp: 'wmma_tensorop', + OpcodeClass.Simt: "simt", + OpcodeClass.TensorOp: "tensorop", + OpcodeClass.WmmaTensorOp: "wmma_tensorop", } OpcodeClassTag = { - OpcodeClass.Simt: 'cutlass::arch::OpClassSimt', - OpcodeClass.TensorOp: 'cutlass::arch::OpClassTensorOp', - OpcodeClass.WmmaTensorOp: 'cutlass::arch::OpClassWmmaTensorOp', + OpcodeClass.Simt: "cutlass::arch::OpClassSimt", + OpcodeClass.TensorOp: "cutlass::arch::OpClassTensorOp", + OpcodeClass.WmmaTensorOp: "cutlass::arch::OpClassWmmaTensorOp", } ################################################################################################### # class OperationKind(enum.Enum): - Gemm = enum_auto() - Conv2d = enum_auto() + Gemm = enum_auto() + Conv2d = enum_auto() + # -OperationKindNames = { - OperationKind.Gemm: 'gemm' - , OperationKind.Conv2d: 'conv2d' -} +OperationKindNames = {OperationKind.Gemm: "gemm", OperationKind.Conv2d: "conv2d"} -# +# class Target(enum.Enum): - library = enum_auto() + library = enum_auto() + ArchitectureNames = { - 50: 'maxwell', - 60: 'pascal', - 61: 'pascal', - 70: 'volta', - 75: 'turing', - 80: 'ampere', + 50: "maxwell", + 60: "pascal", + 61: "pascal", + 70: "volta", + 75: "turing", + 80: "ampere", } ################################################################################################### # def SubstituteTemplate(template, values): - text = 
template - changed = True - while changed: - changed = False - for key, value in values.items(): - regex = "\\$\\{%s\\}" % key - newtext = re.sub(regex, value, text) - if newtext != text: - changed = True - text = newtext - return text + text = template + changed = True + while changed: + changed = False + for key, value in values.items(): + regex = "\\$\\{%s\\}" % key + newtext = re.sub(regex, value, text) + if newtext != text: + changed = True + text = newtext + return text + ################################################################################################### # class GemmKind(enum.Enum): - Gemm = enum_auto() - Sparse = enum_auto() - Universal = enum_auto() - PlanarComplex = enum_auto() - PlanarComplexArray = enum_auto() - SplitKParallel = enum_auto() - GemvBatchedStrided = enum_auto() + Gemm = enum_auto() + Sparse = enum_auto() + Universal = enum_auto() + PlanarComplex = enum_auto() + PlanarComplexArray = enum_auto() + SplitKParallel = enum_auto() + GemvBatchedStrided = enum_auto() + # GemmKindNames = { - GemmKind.Gemm: "gemm", - GemmKind.Sparse: "spgemm", - GemmKind.Universal: "gemm", - GemmKind.PlanarComplex: "gemm_planar_complex", - GemmKind.PlanarComplexArray: "gemm_planar_complex_array", - GemmKind.SplitKParallel: "gemm_split_k_parallel", - GemmKind.GemvBatchedStrided: "gemv_batched_strided", + GemmKind.Gemm: "gemm", + GemmKind.Sparse: "spgemm", + GemmKind.Universal: "gemm", + GemmKind.PlanarComplex: "gemm_planar_complex", + GemmKind.PlanarComplexArray: "gemm_planar_complex_array", + GemmKind.SplitKParallel: "gemm_split_k_parallel", + GemmKind.GemvBatchedStrided: "gemv_batched_strided", } # class EpilogueFunctor(enum.Enum): - LinearCombination = enum_auto() - LinearCombinationClamp = enum_auto() - BiasAddLinearCombination = enum_auto() - BiasAddLinearCombinationRelu = enum_auto() - BiasAddLinearCombinationHSwish = enum_auto() - BiasAddLinearCombinationClamp = enum_auto() - BiasAddLinearCombinationReluClamp = enum_auto() - 
BiasAddLinearCombinationHSwishClamp = enum_auto() + LinearCombination = enum_auto() + LinearCombinationClamp = enum_auto() + BiasAddLinearCombination = enum_auto() + BiasAddLinearCombinationRelu = enum_auto() + BiasAddLinearCombinationHSwish = enum_auto() + BiasAddLinearCombinationClamp = enum_auto() + BiasAddLinearCombinationReluClamp = enum_auto() + BiasAddLinearCombinationHSwishClamp = enum_auto() # EpilogueFunctorTag = { - EpilogueFunctor.LinearCombination: 'cutlass::epilogue::thread::LinearCombination', - EpilogueFunctor.LinearCombinationClamp: 'cutlass::epilogue::thread::LinearCombinationClamp', - EpilogueFunctor.BiasAddLinearCombination: 'cutlass::epilogue::thread::BiasAddLinearCombination', - EpilogueFunctor.BiasAddLinearCombinationRelu: 'cutlass::epilogue::thread::BiasAddLinearCombinationRelu', - EpilogueFunctor.BiasAddLinearCombinationHSwish: 'cutlass::epilogue::thread::BiasAddLinearCombinationHSwish', - EpilogueFunctor.BiasAddLinearCombinationClamp: 'cutlass::epilogue::thread::BiasAddLinearCombinationClamp', - EpilogueFunctor.BiasAddLinearCombinationReluClamp: 'cutlass::epilogue::thread::BiasAddLinearCombinationReluClamp', - EpilogueFunctor.BiasAddLinearCombinationHSwishClamp: 'cutlass::epilogue::thread::BiasAddLinearCombinationHSwishClamp', + EpilogueFunctor.LinearCombination: "cutlass::epilogue::thread::LinearCombination", + EpilogueFunctor.LinearCombinationClamp: "cutlass::epilogue::thread::LinearCombinationClamp", + EpilogueFunctor.BiasAddLinearCombination: "cutlass::epilogue::thread::BiasAddLinearCombination", + EpilogueFunctor.BiasAddLinearCombinationRelu: "cutlass::epilogue::thread::BiasAddLinearCombinationRelu", + EpilogueFunctor.BiasAddLinearCombinationHSwish: "cutlass::epilogue::thread::BiasAddLinearCombinationHSwish", + EpilogueFunctor.BiasAddLinearCombinationClamp: "cutlass::epilogue::thread::BiasAddLinearCombinationClamp", + EpilogueFunctor.BiasAddLinearCombinationReluClamp: "cutlass::epilogue::thread::BiasAddLinearCombinationReluClamp", + 
EpilogueFunctor.BiasAddLinearCombinationHSwishClamp: "cutlass::epilogue::thread::BiasAddLinearCombinationHSwishClamp", } # ShortEpilogueNames = { - EpilogueFunctor.BiasAddLinearCombinationHSwishClamp: 'hswish', - EpilogueFunctor.BiasAddLinearCombinationReluClamp: 'relu', - EpilogueFunctor.BiasAddLinearCombinationClamp: 'id', - EpilogueFunctor.BiasAddLinearCombinationHSwish: 'hswish', - EpilogueFunctor.BiasAddLinearCombinationRelu: 'relu', - EpilogueFunctor.BiasAddLinearCombination: 'id', + EpilogueFunctor.BiasAddLinearCombinationHSwishClamp: "hswish", + EpilogueFunctor.BiasAddLinearCombinationReluClamp: "relu", + EpilogueFunctor.BiasAddLinearCombinationClamp: "id", + EpilogueFunctor.BiasAddLinearCombinationHSwish: "hswish", + EpilogueFunctor.BiasAddLinearCombinationRelu: "relu", + EpilogueFunctor.BiasAddLinearCombination: "id", } - - - - # class SwizzlingFunctor(enum.Enum): - Identity1 = enum_auto() - Identity2 = enum_auto() - Identity4 = enum_auto() - Identity8 = enum_auto() - ConvFpropNCxHWx = enum_auto() - ConvFpropTrans = enum_auto() - ConvDgradNCxHWx = enum_auto() - ConvDgradTrans = enum_auto() + Identity1 = enum_auto() + Identity2 = enum_auto() + Identity4 = enum_auto() + Identity8 = enum_auto() + ConvFpropNCxHWx = enum_auto() + ConvFpropTrans = enum_auto() + ConvDgradNCxHWx = enum_auto() + ConvDgradTrans = enum_auto() + DepthwiseConvolutionFprop = enum_auto() + DepthwiseConvolutionDgrad = enum_auto() + DepthwiseConvolutionWgrad = enum_auto() + # SwizzlingFunctorTag = { - SwizzlingFunctor.Identity1: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>', - SwizzlingFunctor.Identity2: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<2>', - SwizzlingFunctor.Identity4: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>', - SwizzlingFunctor.Identity8: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>', - SwizzlingFunctor.ConvFpropNCxHWx: 'cutlass::conv::threadblock::ConvolutionFpropNCxHWxThreadblockSwizzle', - 
SwizzlingFunctor.ConvFpropTrans: 'cutlass::conv::threadblock::ConvolutionFpropTransThreadblockSwizzle', - SwizzlingFunctor.ConvDgradNCxHWx: 'cutlass::conv::threadblock::ConvolutionDgradNCxHWxThreadblockSwizzle', - SwizzlingFunctor.ConvDgradTrans: 'cutlass::conv::threadblock::ConvolutionDgradTransThreadblockSwizzle', + SwizzlingFunctor.Identity1: "cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>", + SwizzlingFunctor.Identity2: "cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<2>", + SwizzlingFunctor.Identity4: "cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>", + SwizzlingFunctor.Identity8: "cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>", + SwizzlingFunctor.ConvFpropNCxHWx: "cutlass::conv::threadblock::ConvolutionFpropNCxHWxThreadblockSwizzle", + SwizzlingFunctor.ConvFpropTrans: "cutlass::conv::threadblock::ConvolutionFpropTransThreadblockSwizzle", + SwizzlingFunctor.ConvDgradNCxHWx: "cutlass::conv::threadblock::ConvolutionDgradNCxHWxThreadblockSwizzle", + SwizzlingFunctor.ConvDgradTrans: "cutlass::conv::threadblock::ConvolutionDgradTransThreadblockSwizzle", + SwizzlingFunctor.DepthwiseConvolutionFprop: "cutlass::conv::threadblock::DepthwiseConvolutionFpropThreadblockSwizzle", + SwizzlingFunctor.DepthwiseConvolutionDgrad: "cutlass::conv::threadblock::DepthwiseConvolutionDgradThreadblockSwizzle", + SwizzlingFunctor.DepthwiseConvolutionWgrad: "cutlass::conv::threadblock::DepthwiseConvolutionWgradThreadblockSwizzle", } ################################################################################################### + class ConvType(enum.Enum): - Convolution = enum_auto() - BatchConvolution = enum_auto() - Local = enum_auto() - LocalShare = enum_auto() + Convolution = enum_auto() + BatchConvolution = enum_auto() + Local = enum_auto() + LocalShare = enum_auto() + DepthwiseConvolution = enum_auto() + ConvTypeTag = { - ConvType.Convolution: 'cutlass::conv::ConvType::kConvolution', - ConvType.BatchConvolution: 
'cutlass::conv::ConvType::kBatchConvolution', - ConvType.Local: 'cutlass::conv::ConvType::kLocal', - ConvType.LocalShare : 'cutlass::conv::ConvType::kLocalShare', + ConvType.Convolution: "cutlass::conv::ConvType::kConvolution", + ConvType.BatchConvolution: "cutlass::conv::ConvType::kBatchConvolution", + ConvType.Local: "cutlass::conv::ConvType::kLocal", + ConvType.LocalShare: "cutlass::conv::ConvType::kLocalShare", + ConvType.DepthwiseConvolution: "cutlass::conv::ConvType::kDepthwiseConvolution", } # class ConvKind(enum.Enum): - Fprop = enum_auto() - Dgrad = enum_auto() - Wgrad = enum_auto() + Fprop = enum_auto() + Dgrad = enum_auto() + Wgrad = enum_auto() + # ConvKindTag = { - ConvKind.Fprop: 'cutlass::conv::Operator::kFprop', - ConvKind.Dgrad: 'cutlass::conv::Operator::kDgrad', - ConvKind.Wgrad: 'cutlass::conv::Operator::kWgrad' + ConvKind.Fprop: "cutlass::conv::Operator::kFprop", + ConvKind.Dgrad: "cutlass::conv::Operator::kDgrad", + ConvKind.Wgrad: "cutlass::conv::Operator::kWgrad", } ConvKindNames = { - ConvKind.Fprop: 'fprop', - ConvKind.Dgrad: 'dgrad', - ConvKind.Wgrad: 'wgrad', + ConvKind.Fprop: "fprop", + ConvKind.Dgrad: "dgrad", + ConvKind.Wgrad: "wgrad", } # class IteratorAlgorithm(enum.Enum): - Analytic = enum_auto() - Optimized = enum_auto() + Analytic = enum_auto() + Optimized = enum_auto() + # IteratorAlgorithmTag = { - IteratorAlgorithm.Analytic: 'cutlass::conv::IteratorAlgorithm::kAnalytic', - IteratorAlgorithm.Optimized: 'cutlass::conv::IteratorAlgorithm::kOptimized', + IteratorAlgorithm.Analytic: "cutlass::conv::IteratorAlgorithm::kAnalytic", + IteratorAlgorithm.Optimized: "cutlass::conv::IteratorAlgorithm::kOptimized", } IteratorAlgorithmNames = { - IteratorAlgorithm.Analytic: 'analytic', - IteratorAlgorithm.Optimized: 'optimized', + IteratorAlgorithm.Analytic: "analytic", + IteratorAlgorithm.Optimized: "optimized", } # class StrideSupport(enum.Enum): - Strided = enum_auto() - Unity = enum_auto() + Strided = enum_auto() + Unity = enum_auto() + # 
StrideSupportTag = { - StrideSupport.Strided: 'cutlass::conv::StrideSupport::kStrided', - StrideSupport.Unity: 'cutlass::conv::StrideSupport::kUnity', + StrideSupport.Strided: "cutlass::conv::StrideSupport::kStrided", + StrideSupport.Unity: "cutlass::conv::StrideSupport::kUnity", } -StrideSupportNames = { - StrideSupport.Strided: '', - StrideSupport.Unity: 'unity_stride', -} +StrideSupportNames = {StrideSupport.Strided: "", StrideSupport.Unity: "unity_stride"} + class SpecialOptimizeDesc(enum.Enum): - NoneSpecialOpt = enum_auto() - ConvFilterUnity = enum_auto() - DeconvDoubleUpsampling = enum_auto() + NoneSpecialOpt = enum_auto() + ConvFilterUnity = enum_auto() + DeconvDoubleUpsampling = enum_auto() + SpecialOptimizeDescNames = { - SpecialOptimizeDesc.NoneSpecialOpt: 'none', - SpecialOptimizeDesc.ConvFilterUnity: 'conv_filter_unity', - SpecialOptimizeDesc.DeconvDoubleUpsampling: 'deconv_double_upsampling', + SpecialOptimizeDesc.NoneSpecialOpt: "none", + SpecialOptimizeDesc.ConvFilterUnity: "conv_filter_unity", + SpecialOptimizeDesc.DeconvDoubleUpsampling: "deconv_double_upsampling", } SpecialOptimizeDescTag = { - SpecialOptimizeDesc.NoneSpecialOpt: 'cutlass::conv::SpecialOptimizeDesc::NONE', - SpecialOptimizeDesc.ConvFilterUnity: 'cutlass::conv::SpecialOptimizeDesc::CONV_FILTER_UNITY', - SpecialOptimizeDesc.DeconvDoubleUpsampling: 'cutlass::conv::SpecialOptimizeDesc::DECONV_DOUBLE_UPSAMPLING', + SpecialOptimizeDesc.NoneSpecialOpt: "cutlass::conv::SpecialOptimizeDesc::NONE", + SpecialOptimizeDesc.ConvFilterUnity: "cutlass::conv::SpecialOptimizeDesc::CONV_FILTER_UNITY", + SpecialOptimizeDesc.DeconvDoubleUpsampling: "cutlass::conv::SpecialOptimizeDesc::DECONV_DOUBLE_UPSAMPLING", } class ImplicitGemmMode(enum.Enum): - GemmNT = enum_auto() - GemmTN = enum_auto() + GemmNT = enum_auto() + GemmTN = enum_auto() + ImplicitGemmModeNames = { - ImplicitGemmMode.GemmNT: 'gemm_nt', - ImplicitGemmMode.GemmTN: 'gemm_tn', + ImplicitGemmMode.GemmNT: "gemm_nt", + 
ImplicitGemmMode.GemmTN: "gemm_tn", } ImplicitGemmModeTag = { - ImplicitGemmMode.GemmNT: 'cutlass::conv::ImplicitGemmMode::GEMM_NT', - ImplicitGemmMode.GemmTN: 'cutlass::conv::ImplicitGemmMode::GEMM_TN', + ImplicitGemmMode.GemmNT: "cutlass::conv::ImplicitGemmMode::GEMM_NT", + ImplicitGemmMode.GemmTN: "cutlass::conv::ImplicitGemmMode::GEMM_TN", } ################################################################################################### # class MathInstruction: - def __init__(self, instruction_shape, element_a, element_b, element_accumulator, opcode_class, math_operation = MathOperation.multiply_add): - self.instruction_shape = instruction_shape - self.element_a = element_a - self.element_b = element_b - self.element_accumulator = element_accumulator - self.opcode_class = opcode_class - self.math_operation = math_operation + def __init__( + self, + instruction_shape, + element_a, + element_b, + element_accumulator, + opcode_class, + math_operation=MathOperation.multiply_add, + ): + self.instruction_shape = instruction_shape + self.element_a = element_a + self.element_b = element_b + self.element_accumulator = element_accumulator + self.opcode_class = opcode_class + self.math_operation = math_operation # class TileDescription: + def __init__( + self, + threadblock_shape, + stages, + warp_count, + math_instruction, + min_compute, + max_compute, + ): + self.threadblock_shape = threadblock_shape + self.stages = stages + self.warp_count = warp_count + self.math_instruction = math_instruction + self.minimum_compute_capability = min_compute + self.maximum_compute_capability = max_compute + + def procedural_name(self): + return "%dx%d_%dx%d" % ( + self.threadblock_shape[0], + self.threadblock_shape[1], + self.threadblock_shape[2], + self.stages, + ) - def __init__(self, threadblock_shape, stages, warp_count, math_instruction, min_compute, max_compute): - self.threadblock_shape = threadblock_shape - self.stages = stages - self.warp_count = warp_count - 
self.math_instruction = math_instruction - self.minimum_compute_capability = min_compute - self.maximum_compute_capability = max_compute - - def procedural_name(self): - return "%dx%d_%dx%d" % (self.threadblock_shape[0], self.threadblock_shape[1], self.threadblock_shape[2], self.stages) # class TensorDescription: - def __init__(self, element, layout, alignment = 1, complex_transform = ComplexTransform.none): - self.element = element - self.layout = layout - self.alignment = alignment - self.complex_transform = complex_transform + def __init__( + self, element, layout, alignment=1, complex_transform=ComplexTransform.none + ): + self.element = element + self.layout = layout + self.alignment = alignment + self.complex_transform = complex_transform + ################################################################################################### + class GlobalCnt: - cnt = 0 \ No newline at end of file + cnt = 0 diff --git a/dnn/scripts/cutlass_generator/manifest.py b/dnn/scripts/cutlass_generator/manifest.py index 33aafb8a..ee27e668 100644 --- a/dnn/scripts/cutlass_generator/manifest.py +++ b/dnn/scripts/cutlass_generator/manifest.py @@ -10,24 +10,25 @@ import shutil from library import * from gemm_operation import * -from conv2d_operation import * +from conv2d_operation import * ################################################################################################### + class EmitOperationKindLibrary: - def __init__(self, generated_path, kind, args): - self.generated_path = generated_path - self.kind = kind - self.args = args + def __init__(self, generated_path, kind, args): + self.generated_path = generated_path + self.kind = kind + self.args = args - self.emitters = { - OperationKind.Gemm: EmitGemmConfigurationLibrary - , OperationKind.Conv2d: EmitConv2dConfigurationLibrary - } + self.emitters = { + OperationKind.Gemm: EmitGemmConfigurationLibrary, + OperationKind.Conv2d: EmitConv2dConfigurationLibrary, + } - self.configurations = []; + 
self.configurations = [] - self.header_template =""" + self.header_template = """ /* Generated by manifest.py - Do not edit. */ @@ -42,17 +43,19 @@ namespace library { /////////////////////////////////////////////////////////////////////////////////////////////////// """ - self.entry_template = """ + self.entry_template = """ // // Entry point to construct operations // void initialize_all_${operation_name}_operations(Manifest &manifest) { """ - self.configuration_prototype_template = "void initialize_${configuration_name}(Manifest &manifest);\n" - self.configuration_template =" initialize_${configuration_name}(manifest);\n" + self.configuration_prototype_template = ( + "void initialize_${configuration_name}(Manifest &manifest);\n" + ) + self.configuration_template = " initialize_${configuration_name}(manifest);\n" - self.epilogue_template =""" + self.epilogue_template = """ } @@ -63,91 +66,118 @@ void initialize_all_${operation_name}_operations(Manifest &manifest) { """ - # - def __enter__(self): - self.operation_path = os.path.join(self.generated_path, OperationKindNames[self.kind]) - os.mkdir(self.operation_path) + # + def __enter__(self): + self.operation_path = os.path.join( + self.generated_path, OperationKindNames[self.kind] + ) + os.mkdir(self.operation_path) + + self.top_level_path = os.path.join( + self.operation_path, "all_%s_operations.cu" % OperationKindNames[self.kind] + ) + + self.top_level_file = open(self.top_level_path, "w") + self.top_level_file.write(self.header_template) - self.top_level_path = os.path.join(self.operation_path, "all_%s_operations.cu" % OperationKindNames[self.kind]) + self.source_files = [self.top_level_path] - self.top_level_file = open(self.top_level_path, "w") - self.top_level_file.write(self.header_template) + return self - self.source_files = [self.top_level_path,] + # + def emit(self, configuration_name, operations): - return self + with self.emitters[self.kind]( + self.operation_path, configuration_name + ) as 
configuration_emitter: + for operation in operations: + configuration_emitter.emit(operation) - # - def emit(self, configuration_name, operations): + self.source_files.append(configuration_emitter.configuration_path) - with self.emitters[self.kind](self.operation_path, configuration_name) as configuration_emitter: - for operation in operations: - configuration_emitter.emit(operation) - - self.source_files.append(configuration_emitter.configuration_path) + self.configurations.append(configuration_name) + self.top_level_file.write( + SubstituteTemplate( + self.configuration_prototype_template, + {"configuration_name": configuration_name}, + ) + ) - self.configurations.append(configuration_name) - self.top_level_file.write(SubstituteTemplate(self.configuration_prototype_template, {'configuration_name': configuration_name} )) + # + def __exit__(self, exception_type, exception_value, traceback): + self.top_level_file.write( + SubstituteTemplate( + self.entry_template, {"operation_name": OperationKindNames[self.kind]} + ) + ) - # - def __exit__(self, exception_type, exception_value, traceback): - self.top_level_file.write(SubstituteTemplate(self.entry_template, {'operation_name': OperationKindNames[self.kind]})) + for configuration_name in self.configurations: + self.top_level_file.write( + SubstituteTemplate( + self.configuration_template, + {"configuration_name": configuration_name}, + ) + ) - for configuration_name in self.configurations: - self.top_level_file.write(SubstituteTemplate(self.configuration_template, {'configuration_name': configuration_name})) + self.top_level_file.write(self.epilogue_template) + self.top_level_file.close() - self.top_level_file.write(self.epilogue_template) - self.top_level_file.close() ################################################################################################### ################################################################################################### + class Options: - def __init__(self): - pass + def 
__init__(self): + pass + ################################################################################################### # class Manifest: - # - def __init__(self, args): - self.operations = {} - self.args = args + # + def __init__(self, args): + self.operations = {} + self.args = args + + architectures = ( + args.architectures.split(";") if len(args.architectures) else ["50"] + ) + self.compute_capabilities = [int(x) for x in architectures] - architectures = args.architectures.split(';') if len(args.architectures) else ['50',] - self.compute_capabilities = [int(x) for x in architectures] - - self.selected_kernels = [] - - if args.operations == 'all': - self.operations_enabled = [] - else: + self.selected_kernels = [] - operations_list = [ - OperationKind.Gemm - , OperationKind.Conv2d - ] + if args.operations == "all": + self.operations_enabled = [] + else: - self.operations_enabled = [x for x in operations_list if OperationKindNames[x] in args.operations.split(',')] + operations_list = [OperationKind.Gemm, OperationKind.Conv2d] - if args.kernels == 'all': - self.kernel_names = [] - else: - self.kernel_names = [x for x in args.kernels.split(',') if x != ''] + self.operations_enabled = [ + x + for x in operations_list + if OperationKindNames[x] in args.operations.split(",") + ] - self.ignore_kernel_names = [x for x in args.ignore_kernels.split(',') if x != ''] + if args.kernels == "all": + self.kernel_names = [] + else: + self.kernel_names = [x for x in args.kernels.split(",") if x != ""] - if args.kernel_filter_file is None: - self.kernel_filter_list = [] - else: - self.kernel_filter_list = self.get_kernel_filters(args.kernel_filter_file) + self.ignore_kernel_names = [ + x for x in args.ignore_kernels.split(",") if x != "" + ] + if args.kernel_filter_file is None: + self.kernel_filter_list = [] + else: + self.kernel_filter_list = self.get_kernel_filters(args.kernel_filter_file) - self.operation_count = 0 - self.operations_by_name = {} - self.top_level_prologue 
= ''' + self.operation_count = 0 + self.operations_by_name = {} + self.top_level_prologue = """ #include "cutlass/library/library.h" #include "cutlass/library/manifest.h" @@ -159,208 +189,241 @@ ${prototypes} void initialize_all(Manifest &manifest) { -''' - self.top_level_reserve = ' manifest.reserve(${operation_count});\n\n' - self.top_level_epilogue = ''' +""" + self.top_level_reserve = " manifest.reserve(${operation_count});\n\n" + self.top_level_epilogue = """ } } // namespace library } // namespace cutlass -''' - - - def get_kernel_filters (self, kernelListFile): - if os.path.isfile(kernelListFile): - with open(kernelListFile, 'r') as fileReader: - lines = [line.rstrip() for line in fileReader if not line.startswith("#")] - - lines = [re.compile(line) for line in lines if line] - return lines - else: - return [] +""" + def get_kernel_filters(self, kernelListFile): + if os.path.isfile(kernelListFile): + with open(kernelListFile, "r") as fileReader: + lines = [ + line.rstrip() for line in fileReader if not line.startswith("#") + ] + lines = [re.compile(line) for line in lines if line] + return lines + else: + return [] - def filter_out_kernels(self, kernel_name, kernel_filter_list): + def filter_out_kernels(self, kernel_name, kernel_filter_list): - for kernel_filter_re in kernel_filter_list: - if kernel_filter_re.search(kernel_name) is not None: - return True - - return False + for kernel_filter_re in kernel_filter_list: + if kernel_filter_re.search(kernel_name) is not None: + return True - - # - def _filter_string_matches(self, filter_string, haystack): - ''' Returns true if all substrings appear in the haystack in order''' - substrings = filter_string.split('*') - for sub in substrings: - idx = haystack.find(sub) - if idx < 0: return False - haystack = haystack[idx + len(sub):] - return True - - # - def filter(self, operation): - ''' Filtering operations based on various criteria''' - - # filter based on compute capability - enabled = False - for cc in 
self.compute_capabilities: - if cc >= operation.tile_description.minimum_compute_capability and \ - cc <= operation.tile_description.maximum_compute_capability: - - enabled = True - break - - if not enabled: - return False - - if len(self.operations_enabled) and not operation.operation_kind in self.operations_enabled: - return False - - # eliminate duplicates - if operation.procedural_name() in self.operations_by_name.keys(): - return False - - # Filter based on list of valid substrings - if len(self.kernel_names): - name = operation.procedural_name() - enabled = False - - # compare against the include list - for name_substr in self.kernel_names: - if self._filter_string_matches(name_substr, name): - enabled = True - break - - # compare against the exclude list - for name_substr in self.ignore_kernel_names: - if self._filter_string_matches(name_substr, name): - enabled = False - break - - if len(self.kernel_filter_list) > 0: + + # + def _filter_string_matches(self, filter_string, haystack): + """ Returns true if all substrings appear in the haystack in order""" + substrings = filter_string.split("*") + for sub in substrings: + idx = haystack.find(sub) + if idx < 0: + return False + haystack = haystack[idx + len(sub) :] + return True + + # + def filter(self, operation): + """ Filtering operations based on various criteria""" + + # filter based on compute capability enabled = False - if self.filter_out_kernels(operation.procedural_name(), self.kernel_filter_list): - enabled = True + for cc in self.compute_capabilities: + if ( + cc >= operation.tile_description.minimum_compute_capability + and cc <= operation.tile_description.maximum_compute_capability + ): + + enabled = True + break + + if not enabled: + return False + + if ( + len(self.operations_enabled) + and not operation.operation_kind in self.operations_enabled + ): + return False + + # eliminate duplicates + if operation.procedural_name() in self.operations_by_name.keys(): + return False + + # Filter based on 
list of valid substrings + if len(self.kernel_names): + name = operation.procedural_name() + enabled = False + + # compare against the include list + for name_substr in self.kernel_names: + if self._filter_string_matches(name_substr, name): + enabled = True + break + + # compare against the exclude list + for name_substr in self.ignore_kernel_names: + if self._filter_string_matches(name_substr, name): + enabled = False + break + + if len(self.kernel_filter_list) > 0: + enabled = False + if self.filter_out_kernels( + operation.procedural_name(), self.kernel_filter_list + ): + enabled = True + + # todo: filter based on compute data type + return enabled + + # + + # + def append(self, operation): + """ + Inserts the operation. + operation_kind -> configuration_name -> [] + """ - # todo: filter based on compute data type - return enabled - # + if self.filter(operation): - # - def append(self, operation): - ''' - Inserts the operation. + self.selected_kernels.append(operation.procedural_name()) - operation_kind -> configuration_name -> [] - ''' + self.operations_by_name[operation.procedural_name()] = operation - if self.filter(operation): - - self.selected_kernels.append(operation.procedural_name()) + # add the configuration + configuration_name = operation.configuration_name() - self.operations_by_name[operation.procedural_name()] = operation + if operation.operation_kind not in self.operations.keys(): + self.operations[operation.operation_kind] = {} - # add the configuration - configuration_name = operation.configuration_name() + if ( + configuration_name + not in self.operations[operation.operation_kind].keys() + ): + self.operations[operation.operation_kind][configuration_name] = [] - if operation.operation_kind not in self.operations.keys(): - self.operations[operation.operation_kind] = {} + self.operations[operation.operation_kind][configuration_name].append( + operation + ) + self.operation_count += 1 - if configuration_name not in 
self.operations[operation.operation_kind].keys(): - self.operations[operation.operation_kind][configuration_name] = [] + # - self.operations[operation.operation_kind][configuration_name].append(operation) - self.operation_count += 1 - # + # + def emit(self, target=GeneratorTarget.Library): - # - def emit(self, target = GeneratorTarget.Library): + operation_emitters = {GeneratorTarget.Library: EmitOperationKindLibrary} - operation_emitters = { - GeneratorTarget.Library: EmitOperationKindLibrary - } + generated_path = os.path.join(self.args.curr_build_dir, "generated") - generated_path = os.path.join(self.args.curr_build_dir, 'generated') + # create generated/ + if os.path.exists(generated_path): + shutil.rmtree(generated_path) - # create generated/ - if os.path.exists(generated_path): - shutil.rmtree(generated_path) + os.mkdir(generated_path) - os.mkdir(generated_path) + source_files = [] - source_files = [] + top_level_path = os.path.join(generated_path, "initialize_all.cpp") + with open(top_level_path, "w") as top_level_file: - top_level_path = os.path.join(generated_path, 'initialize_all.cpp') - with open(top_level_path, 'w') as top_level_file: + if target == GeneratorTarget.Library: + source_files.append(top_level_path) - if target == GeneratorTarget.Library: - source_files.append(top_level_path) + prototypes = [] + for operation_kind, configurations in self.operations.items(): + prototypes.append( + SubstituteTemplate( + "void initialize_all_${operation_kind}_operations(Manifest &manifest);", + {"operation_kind": OperationKindNames[operation_kind]}, + ) + ) - prototypes = [] - for operation_kind, configurations in self.operations.items(): - prototypes.append(SubstituteTemplate( - "void initialize_all_${operation_kind}_operations(Manifest &manifest);", - {'operation_kind': OperationKindNames[operation_kind]})) + top_level_file.write( + SubstituteTemplate( + self.top_level_prologue, {"prototypes": "\n".join(prototypes)} + ) + ) - 
top_level_file.write(SubstituteTemplate(self.top_level_prologue, - {'prototypes': "\n".join(prototypes)})) + top_level_file.write( + SubstituteTemplate( + self.top_level_reserve, + {"operation_count": str(self.operation_count)}, + ) + ) - top_level_file.write(SubstituteTemplate( - self.top_level_reserve, {'operation_count': str(self.operation_count)})) + # for each operation kind, emit initializer for all configurations + for operation_kind, configurations in self.operations.items(): - # for each operation kind, emit initializer for all configurations - for operation_kind, configurations in self.operations.items(): - - with operation_emitters[target](generated_path, operation_kind, self.args) as operation_kind_emitter: - for configuration_name, operations in configurations.items(): - operation_kind_emitter.emit(configuration_name, operations) + with operation_emitters[target]( + generated_path, operation_kind, self.args + ) as operation_kind_emitter: + for configuration_name, operations in configurations.items(): + operation_kind_emitter.emit(configuration_name, operations) - source_files += operation_kind_emitter.source_files + source_files += operation_kind_emitter.source_files - top_level_file.write(SubstituteTemplate( - " initialize_all_${operation_kind}_operations(manifest);\n", - {'operation_kind': OperationKindNames[operation_kind]})) + top_level_file.write( + SubstituteTemplate( + " initialize_all_${operation_kind}_operations(manifest);\n", + {"operation_kind": OperationKindNames[operation_kind]}, + ) + ) - top_level_file.write(self.top_level_epilogue) + top_level_file.write(self.top_level_epilogue) - # write the manifest.cmake file containing paths from all targets - manifest_path = os.path.join(generated_path, "manifest.cmake") - with open(manifest_path, "w") as manifest_file: + # write the manifest.cmake file containing paths from all targets + manifest_path = os.path.join(generated_path, "manifest.cmake") + with open(manifest_path, "w") as 
manifest_file: - target_name = 'cutlass_library_objs' + target_name = "cutlass_library_objs" - target_text = SubstituteTemplate("""cutlass_target_sources( + target_text = SubstituteTemplate( + """cutlass_target_sources( ${target_name} BATCH_SOURCES ON PRIVATE -""", { 'target_name': target_name}) +""", + {"target_name": target_name}, + ) - manifest_file.write(target_text) + manifest_file.write(target_text) + + for source_file in source_files: + manifest_file.write(" %s\n" % str(source_file.replace("\\", "/"))) + manifest_file.write(")") + + # - for source_file in source_files: - manifest_file.write(" %s\n" % str(source_file.replace('\\', '/'))) - manifest_file.write(")") - # ################################################################################################### + def GenerateManifest(args, operations, output_dir): - assert isinstance(operations, list) - if len(operations) == 0: - return - op = operations[0] - required_cuda_ver_major = op.required_cuda_ver_major - required_cuda_ver_minor = op.required_cuda_ver_minor - - manifest_path = os.path.join(output_dir, "all_%s_%s_operations.cu" % (args.operations, args.type)) - f = open(manifest_path, "w") - f.write(""" + assert isinstance(operations, list) + if len(operations) == 0: + return + op = operations[0] + required_cuda_ver_major = op.required_cuda_ver_major + required_cuda_ver_minor = op.required_cuda_ver_minor + + manifest_path = os.path.join( + output_dir, "all_%s_%s_operations.cu" % (args.operations, args.type) + ) + f = open(manifest_path, "w") + f.write( + """ /* Generated by generator.py - Do not edit. 
*/ @@ -374,24 +437,35 @@ def GenerateManifest(args, operations, output_dir): namespace cutlass { namespace library { -""" % (str(required_cuda_ver_major), str(required_cuda_ver_major), str(required_cuda_ver_minor))) - - for op in operations: - f.write("void initialize_%s(Manifest &manifest);\n" % op.procedural_name()) - - f.write(""" +""" + % ( + str(required_cuda_ver_major), + str(required_cuda_ver_major), + str(required_cuda_ver_minor), + ) + ) + + for op in operations: + f.write("void initialize_%s(Manifest &manifest);\n" % op.procedural_name()) + + f.write( + """ void initialize_all_%s_%s_operations(Manifest &manifest) { -""" % (args.operations, args.type)) +""" + % (args.operations, args.type) + ) - for op in operations: - f.write(" initialize_%s(manifest);\n" % op.procedural_name()) + for op in operations: + f.write(" initialize_%s(manifest);\n" % op.procedural_name()) - f.write(""" + f.write( + """ } } // namespace library } // namespace cutlass #endif -""") - f.close() +""" + ) + f.close() diff --git a/dnn/src/CMakeLists.txt b/dnn/src/CMakeLists.txt index d0566165..d9f28286 100644 --- a/dnn/src/CMakeLists.txt +++ b/dnn/src/CMakeLists.txt @@ -181,6 +181,8 @@ if(MGE_WITH_CUDA) gen_cutlass_kimpl(conv2d simt CUTLASS_SOURCES) gen_cutlass_kimpl(conv2d tensorop8816 CUTLASS_SOURCES) gen_cutlass_kimpl(conv2d tensorop8832 CUTLASS_SOURCES) + gen_cutlass_kimpl(dwconv2d_fprop simt CUTLASS_SOURCES) + gen_cutlass_kimpl(dwconv2d_fprop tensorop884 CUTLASS_SOURCES) list(APPEND SOURCES ${CUTLASS_SOURCES}) list(APPEND SOURCES ${CUSOURCES}) endif() diff --git a/dnn/src/cuda/conv_bias/algo.cpp b/dnn/src/cuda/conv_bias/algo.cpp index 6aa2578e..086746ad 100644 --- a/dnn/src/cuda/conv_bias/algo.cpp +++ b/dnn/src/cuda/conv_bias/algo.cpp @@ -92,6 +92,7 @@ ConvBiasForwardImpl::AlgoPack::AlgoPack() { for (auto&& algo : int8_nchw4_dotprod) { all_algos.push_back(&algo); } + fill_dwconv_algos(); all_algos.push_back(&int8_chwn4_dotprod); all_algos.push_back(&fallback_nchw_qs8); for (size_t 
i = all_algo_size; i < all_algos.size(); ++i) { @@ -301,6 +302,32 @@ void ConvBiasForwardImpl::AlgoPack::fill_imma_algos() { } #endif +void ConvBiasForwardImpl::AlgoPack::fill_dwconv_algos() { + using AlgoParam = AlgoCutlassConvolutionBase::AlgoParam; + f32_implicit_bmm.emplace_back(AlgoParam{128, 128, 8, 32, 64, 8, 1, 1, 1, 2}); + f32_implicit_bmm.emplace_back(AlgoParam{128, 64, 8, 64, 32, 8, 1, 1, 1, 2}); + f32_implicit_bmm.emplace_back(AlgoParam{128, 32, 8, 64, 32, 8, 1, 1, 1, 2}); + f32_implicit_bmm.emplace_back(AlgoParam{32, 128, 8, 32, 64, 8, 1, 1, 1, 2}); + f32_implicit_bmm.emplace_back(AlgoParam{64, 128, 8, 64, 32, 8, 1, 1, 1, 2}); + f32_implicit_bmm.emplace_back(AlgoParam{64, 64, 8, 64, 32, 8, 1, 1, 1, 2}); + f32_implicit_bmm.emplace_back(AlgoParam{32, 64, 8, 32, 64, 8, 1, 1, 1, 2}); + f32_implicit_bmm.emplace_back(AlgoParam{32, 32, 8, 32, 32, 8, 1, 1, 1, 2}); + f32_implicit_bmm.emplace_back(AlgoParam{64, 32, 8, 64, 32, 8, 1, 1, 1, 2}); + for (auto&& algo : f32_implicit_bmm) { + all_algos.push_back(&algo); + } +#if CUDA_VERSION >= 10020 + f16_implicit_bmm.emplace_back(AlgoParam{128, 128, 32, 32, 32, 32, 8, 8, 4, 2}); + f16_implicit_bmm.emplace_back(AlgoParam{128, 256, 32, 64, 64, 32, 8, 8, 4, 2}); + f16_implicit_bmm.emplace_back(AlgoParam{128, 64, 32, 32, 32, 32, 8, 8, 4, 2}); + f16_implicit_bmm.emplace_back(AlgoParam{64, 128, 32, 32, 32, 32, 8, 8, 4, 2}); + f16_implicit_bmm.emplace_back(AlgoParam{64, 64, 32, 32, 32, 32, 8, 8, 4, 2}); + for (auto&& algo : f16_implicit_bmm) { + all_algos.push_back(&algo); + } +#endif +} + void ConvBiasForwardImpl::AlgoPack::fill_dp4a_algos() { using AlgoParam = AlgoInt8NCHW4DotProdImplicitGemm::AlgoParam; int8_nchw4_dotprod.emplace_back(AlgoParam{128, 128, 32, 64, 32, 32, 1, 1, 4, 2}); diff --git a/dnn/src/cuda/conv_bias/algo.h b/dnn/src/cuda/conv_bias/algo.h index 682357de..ea0e6053 100644 --- a/dnn/src/cuda/conv_bias/algo.h +++ b/dnn/src/cuda/conv_bias/algo.h @@ -84,7 +84,9 @@ public: 
CUDA_IMPLICIT_GEMM_1X1_SASS_NCHW32_IMMA_INT8, CUDA_IMPLICIT_GEMM_SASS_NCHW64_IMMA_INT4_INT4, CUDA_IMPLICIT_GEMM_SASS_NCHW64_IMMA_UINT4_INT4, - CUDA_FALLBACK_NCHW_INT4 + CUDA_FALLBACK_NCHW_INT4, + CUDA_IMPLICIT_BATCHED_GEMM_FMA_NCHW_F32, + CUDA_IMPLICIT_BATCHED_GEMM_HMMA_NCHW_F16, }; using Mapper = std::unordered_map; @@ -503,6 +505,8 @@ public: * +----+--- AlgoInt4Int4NHWCIMMAImplicitGemm * +----+--- AlgoUInt4Int4NHWCIMMAImplicitGemm * + + * +--- AlgoFloat32NCHWImplicitBatchedGemm + * +--- AlgoFloat16NCHWHMMAImplicitBatchedGemm */ /* @@ -516,7 +520,13 @@ public: // corresponds to cutlass::conv::ConvType. we hope that algo.h does not // depend on cutlass headers - enum class ConvType { kConvolution, kBatchConvolution, kLocal, kLocalShare }; + enum class ConvType { + kConvolution, + kBatchConvolution, + kLocal, + kLocalShare, + kDepthwiseConvolution, + }; // common parameters for operation selection struct AlgoParam { @@ -558,7 +568,8 @@ public: size_t wo, size_t ph, size_t pw, size_t sh, size_t sw, size_t dh, size_t dw, const void* alpha, const void* beta, const void* gamma, const void* delta, const void* theta, const void* threshold, const void* dst_scale, - cudaStream_t stream, const void* extra_param = nullptr) const; + cudaStream_t stream, const void* extra_param = nullptr, + size_t groups = 1) const; protected: AlgoParam m_algo_param; @@ -992,6 +1003,54 @@ private: }; #endif +class ConvBiasForwardImpl::AlgoFloat32NCHWFMAImplicitBatchedGemm final + : public AlgoCutlassConvolutionBase { +public: + AlgoFloat32NCHWFMAImplicitBatchedGemm(AlgoParam algo_param) + : AlgoCutlassConvolutionBase(algo_param) { + m_name = ConvBias::algo_name( + ssprintf( + "FLOAT32_NCHW_FMA_IMPLICIT_BATCHED_GEMM%s", + m_algo_param.to_string().c_str()), + ConvBias::DirectParam{}); + } + bool is_available(const SizeArgs& args) const override; + size_t get_workspace_in_bytes(const SizeArgs& /* args */) const override { + return 0; + } + void exec(const ExecArgs& args) const override; + const 
char* name() const override { return m_name.c_str(); }; + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } + MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_BATCHED_GEMM_FMA_NCHW_F32); + +private: + std::string m_name; +}; + +class ConvBiasForwardImpl::AlgoFloat16NCHWHMMAImplicitBatchedGemm final + : public AlgoCutlassConvolutionBase { +public: + AlgoFloat16NCHWHMMAImplicitBatchedGemm(AlgoParam algo_param) + : AlgoCutlassConvolutionBase(algo_param) { + m_name = ConvBias::algo_name( + ssprintf( + "FLOAT16_NCHW_HMMA_IMPLICIT_BATCHED_GEMM%s", + m_algo_param.to_string().c_str()), + ConvBias::DirectParam{}); + } + bool is_available(const SizeArgs& args) const override; + size_t get_workspace_in_bytes(const SizeArgs& /* args */) const override { + return 0; + } + void exec(const ExecArgs& args) const override; + const char* name() const override { return m_name.c_str(); }; + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } + MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_BATCHED_GEMM_HMMA_NCHW_F16); + +private: + std::string m_name; +}; + class ConvBiasForwardImpl::AlgoBFloat16 final : public AlgoBase { public: bool is_available(const SizeArgs& args) const override; @@ -1048,6 +1107,8 @@ public: std::vector int4_int4_nhwc_imma; std::vector uint4_int4_nhwc_imma; #endif + std::vector f32_implicit_bmm; + std::vector f16_implicit_bmm; AlgoGroupConvGeneral group; AlgoBFloat16 bfloat16; @@ -1063,6 +1124,7 @@ private: #endif void fill_cudnn_algos(); void fill_dp4a_algos(); + void fill_dwconv_algos(); }; } // namespace cuda diff --git a/dnn/src/cuda/conv_bias/cutlass_convolution_base.cpp b/dnn/src/cuda/conv_bias/cutlass_convolution_base.cpp index d6dcf2de..ab578546 100644 --- a/dnn/src/cuda/conv_bias/cutlass_convolution_base.cpp +++ b/dnn/src/cuda/conv_bias/cutlass_convolution_base.cpp @@ -74,13 +74,18 @@ cutlass::conv::ConvType convert_conv_type(Base::ConvType conv_type) { return cutlass::conv::ConvType::kLocal; case 
Base::ConvType::kLocalShare: return cutlass::conv::ConvType::kLocalShare; + case Base::ConvType::kDepthwiseConvolution: + return cutlass::conv::ConvType::kDepthwiseConvolution; default: megdnn_assert(0, "invalid conv type"); } } -NumericTypeID convert_dtype(DTypeEnum dtype) { - switch (dtype) { +NumericTypeID convert_dtype(DType dtype) { + // just make convolution with no bias happy + if (!dtype.valid()) + return NumericTypeID::kF32; + switch (dtype.enumv()) { case DTypeEnum::Float32: return NumericTypeID::kF32; case DTypeEnum::Float16: @@ -100,6 +105,21 @@ NumericTypeID convert_dtype(DTypeEnum dtype) { } } +NumericTypeID get_accumulator_dtype( + DType dtype, const param::ConvBias::ComputeMode comp_mode) { + if (dtype.category() == DTypeCategory::QUANTIZED) { + return NumericTypeID::kS32; + } else { + megdnn_assert(dtype.category() == DTypeCategory::FLOAT); + if (comp_mode == param::ConvBias::ComputeMode::DEFAULT) { + return convert_dtype(dtype); + } else { + megdnn_assert(comp_mode == param::ConvBias::ComputeMode::FLOAT32); + return NumericTypeID::kF32; + } + } +} + struct LayoutPack { LayoutTypeID src; LayoutTypeID filter; @@ -149,6 +169,9 @@ LayoutPack get_layout_pack(const param::ConvBias::Format format, int access_type default: megdnn_assert(0, "invalid access_type"); } + case Format::NCHW: + return {LayoutTypeID::kTensorNCHW, LayoutTypeID::kTensorNCHW, + LayoutTypeID::kTensorNCHW, LayoutTypeID::kTensorNCHW}; default: megdnn_assert(0, "invalid format"); } @@ -177,6 +200,93 @@ EpilogueType get_epilogue_type(const param::ConvBias::NonlineMode mode, bool cla megdnn_assert(0, "invalid nonlinear mode"); } +std::pair get_tensor_alignment( + const param::ConvBias::Format format, const TensorLayout& src, + const TensorLayout& filter, const Base::AlgoParam& algo_param, + bool is_chanwise) { + int alignment_src = 0; + int alignment_filter = 0; + + using Format = param::ConvBias::Format; + + // get tensor alignment for tensor op operations + // for tensor op operations, 
the alignment is determined by the size of a vector + auto get_tensor_alignment_tensor_op = [&]() { + switch (format) { + /// case int8 + case Format::NCHW32: + case Format::NCHW32_NCHW4: + alignment_src = 16; + alignment_filter = 16; + break; + /// case int4 or uint4 + case Format::NCHW64: + alignment_src = 32; + alignment_filter = 32; + break; + case Format::NHWC: + alignment_src = alignment_filter = algo_param.access_size; + break; + default: + megdnn_throw("invalid format"); + }; + }; + + // get tensor alignment for dot product operations + // for integer dot product operations, alignment src is always 4 + // and the alignment filter is determined by the threadblock shape + auto get_tensor_alignment_dp4a = [&]() { + megdnn_assert( + format == Format::NCHW4 || format == Format::NCHW4_NCHW || + format == Format::NCHW4_NHWC || format == Format::NCHW4_NCHW32); + alignment_src = 4; + // determine alignment filter + constexpr int warp_size = 32; + int threads = warp_size * algo_param.threadblock_m * algo_param.threadblock_n * + algo_param.threadblock_k / + (algo_param.warp_m * algo_param.warp_n * algo_param.warp_k); + int threadblock_loads = filter.dtype.size( + algo_param.threadblock_m * algo_param.threadblock_n * + algo_param.threadblock_k); + int load_per_thread = threadblock_loads / threads; + if (load_per_thread >= 16) + alignment_filter = 16; + else if (load_per_thread >= 8) + alignment_filter = 8; + else { + megdnn_assert(load_per_thread >= 4); + alignment_filter = 4; + } + }; + + // get tensor alignment for depthwise convolution + auto get_tensor_alignment_dwconv2d_nchw = [&]() { + alignment_filter = 1; + size_t wi = src.dtype.size(src[3]); // width extent in bytes + for (size_t candidate : {16, 4, 2}) { + if (wi % candidate == 0) { + alignment_src = candidate; + break; + } + } + alignment_src /= src.dtype.size(1); + }; + + if (format == Format::NCHW32 || format == Format::NCHW32_NCHW4 || + format == Format::NCHW64 || format == Format::NHWC) {
get_tensor_alignment_tensor_op(); + } else if ( + format == Format::NCHW4 || format == Format::NCHW4_NCHW || + format == Format::NCHW4_NHWC || format == Format::NCHW4_NCHW32) { + get_tensor_alignment_dp4a(); + } else { + /// the following is used for depthwise convolution + megdnn_assert(format == Format::NCHW && is_chanwise); + get_tensor_alignment_dwconv2d_nchw(); + } + megdnn_assert(alignment_src >= 1 && alignment_filter >= 1); + return {alignment_src, alignment_filter}; +} } // namespace const Operation* ConvBiasForwardImpl::AlgoCutlassConvolutionBase::get_cutlass_conv_op( @@ -185,23 +295,36 @@ const Operation* ConvBiasForwardImpl::AlgoCutlassConvolutionBase::get_cutlass_co auto&& param = args.opr->param(); auto layouts = get_layout_pack(param.format, m_algo_param.access_size); auto epilogue_type = get_epilogue_type( - param.nonlineMode, args.dst_layout->dtype.enumv() != DTypeEnum::Float32); + param.nonlineMode, + args.dst_layout->dtype.category() != DTypeCategory::FLOAT); cutlass::conv::SpecialOptimizeDesc special_optimization = (use_conv_filter_unity_opt) ? 
cutlass::conv::SpecialOptimizeDesc::CONV_FILTER_UNITY : cutlass::conv::SpecialOptimizeDesc::NONE; + int alignment_src, alignment_filter; + auto&& fm = args.filter_meta; + bool is_chanwise = param.sparse == param::ConvBias::Sparse::GROUP && fm.icpg == 1 && + fm.ocpg == 1; + std::tie(alignment_src, alignment_filter) = get_tensor_alignment( + param.format, *args.src_layout, *args.filter_layout, m_algo_param, + is_chanwise); + + auto accumulator_dtype = + get_accumulator_dtype(args.src_layout->dtype, param.compute_mode); + ConvolutionKey key{ convert_conv_op(conv_op), - convert_dtype(args.src_layout->dtype.enumv()), + convert_dtype(args.src_layout->dtype), layouts.src, - convert_dtype(args.filter_layout->dtype.enumv()), + convert_dtype(args.filter_layout->dtype), layouts.filter, - convert_dtype(args.dst_layout->dtype.enumv()), + convert_dtype(args.dst_layout->dtype), layouts.dst, - convert_dtype(args.bias_layout->dtype.enumv()), + convert_dtype(args.bias_layout->dtype), layouts.bias, + accumulator_dtype, convert_conv_type(conv_type), m_algo_param.threadblock_m, m_algo_param.threadblock_n, @@ -215,6 +338,8 @@ const Operation* ConvBiasForwardImpl::AlgoCutlassConvolutionBase::get_cutlass_co epilogue_type, m_algo_param.stage, special_optimization, + alignment_src, + alignment_filter, without_shared_load}; return Singleton::get().operation_table.find_op(key); @@ -227,13 +352,16 @@ void ConvBiasForwardImpl::AlgoCutlassConvolutionBase::execute_cutlass_conv_op( size_t pw, size_t sh, size_t sw, size_t dh, size_t dw, const void* alpha, const void* beta, const void* gamma, const void* delta, const void* theta, const void* threshold, const void* dst_scale, cudaStream_t stream, - const void* extra_param) const { + const void* extra_param, size_t groups) const { // gcc prints warnings when size_t values are implicitly narrowed to int cutlass::conv::Conv2dProblemSize problem_size{ - int(n), int(hi), int(wi), int(ci), - int(co), int(fh), int(fw), int(ho), - int(wo), int(ph), int(pw), 
int(sh), - int(sw), int(dh), int(dw), cutlass::conv::Mode::kCrossCorrelation}; + int(n), int(hi), int(wi), int(ci), + int(co), int(fh), int(fw), int(ho), + int(wo), int(ph), int(pw), int(sh), + int(sw), int(dh), int(dw), cutlass::conv::Mode::kCrossCorrelation, + 1, // split k slices, always 1 + int(groups), // groups + }; ConvolutionArguments conv_args{ problem_size, src, filter, bias, z, dst, alpha, diff --git a/dnn/src/cuda/conv_bias/implicit_batched_gemm_float16_nchw_hmma.cpp b/dnn/src/cuda/conv_bias/implicit_batched_gemm_float16_nchw_hmma.cpp new file mode 100644 index 00000000..74843e9d --- /dev/null +++ b/dnn/src/cuda/conv_bias/implicit_batched_gemm_float16_nchw_hmma.cpp @@ -0,0 +1,95 @@ +/** + * \file dnn/src/cuda/conv_bias/implicit_batched_gemm_float16_nchw_hmma.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied.
+ */ + +#include "src/common/conv_bias.h" +#include "src/cuda/conv_bias/algo.h" +#include "src/cuda/convolution_helper/parameter.cuh" +#include "src/cuda/utils.h" + +using namespace megdnn; +using namespace cuda; + +bool ConvBiasForwardImpl::AlgoFloat16NCHWHMMAImplicitBatchedGemm::is_available( + const SizeArgs& args) const { +#define RETURN_IF_FALSE(stmt_) \ + if (!(stmt_)) \ + return false; + RETURN_IF_FALSE( + args.src_layout->is_contiguous() && args.dst_layout->is_contiguous()); + using Param = param::ConvBias; + using Format = Param::Format; + using Sparse = Param::Sparse; + using Mode = Param::Mode; + auto&& param = args.opr->param(); + auto&& fm = args.filter_meta; + RETURN_IF_FALSE( + param.format == Format::NCHW && + args.src_layout->dtype.enumv() == DTypeEnum::Float16 && + args.filter_layout->dtype.enumv() == DTypeEnum::Float16 && + args.dst_layout->dtype.enumv() == DTypeEnum::Float16); + RETURN_IF_FALSE( + args.bias_layout->ndim <= 0 || + (args.bias_layout->dtype.enumv() == DTypeEnum::Float16 && + check_bias_share_in_channel(*args.bias_layout, param.format))); + RETURN_IF_FALSE( + args.z_layout->ndim <= 0 || + args.z_layout->dtype.enumv() == DTypeEnum::Float16); + RETURN_IF_FALSE(param.sparse == Sparse::GROUP); + RETURN_IF_FALSE(param.mode == Mode::CROSS_CORRELATION); + // check if channelwise convolution + RETURN_IF_FALSE(fm.icpg == 1 && fm.ocpg == 1); + const auto* op = get_cutlass_conv_op( + args, ConvOperator::kFprop, ConvType::kDepthwiseConvolution, false, false); + RETURN_IF_FALSE(op != nullptr); + return true; +#undef RETURN_IF_FALSE +} + +void ConvBiasForwardImpl::AlgoFloat16NCHWHMMAImplicitBatchedGemm::exec( + const ExecArgs& args) const { + auto&& param = args.opr->param(); + auto&& fm = args.filter_meta; + size_t n = args.src_layout->operator[](0), hi = args.src_layout->operator[](2), + wi = args.src_layout->operator[](3); + size_t ho = args.dst_layout->operator[](2), wo = args.dst_layout->operator[](3); + size_t co = fm.group; + size_t ci = 
co; + // check if channelwise convolution + megdnn_assert(fm.icpg == 1 && fm.ocpg == 1); + auto&& stream = cuda_stream(args.opr->handle()); + + float alpha = 1.f; + float beta = args.bias_layout->ndim > 0 ? 1.f : 0.f; + void* bias_ptr = args.bias_layout->ndim > 0 ? args.bias_tensor->raw_ptr() : nullptr; + float gamma = args.z_layout->ndim > 0 ? 1.f : 0.f; + void* z_ptr = args.z_layout->ndim > 0 ? args.z_tensor->raw_ptr() : nullptr; + + // dummy parameters, used for quantization cases + float theta = 0.f; + float delta = 0.f; + float threshold = 0.f; + + const auto* op = get_cutlass_conv_op( + args, ConvOperator::kFprop, ConvType::kDepthwiseConvolution, false, false); + + UNPACK_CONV_PARAMETER(fm, param); + MARK_USED_VAR + execute_cutlass_conv_op( + op, args.src_tensor->raw_ptr(), args.filter_tensor->raw_ptr(), bias_ptr, + z_ptr, args.dst_tensor->raw_ptr(), nullptr, n, hi, wi, ci, co, fh, fw, ho, + wo, ph, pw, sh, sw, dh, dw, &alpha, &beta, &gamma, &delta, &theta, + &threshold, nullptr, stream, nullptr, fm.group); + + after_kernel_launch(); +} + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/conv_bias/implicit_batched_gemm_float32_nchw_fma.cpp b/dnn/src/cuda/conv_bias/implicit_batched_gemm_float32_nchw_fma.cpp new file mode 100644 index 00000000..9d5df71d --- /dev/null +++ b/dnn/src/cuda/conv_bias/implicit_batched_gemm_float32_nchw_fma.cpp @@ -0,0 +1,95 @@ +/** + * \file dnn/src/cuda/conv_bias/implicit_batched_gemm_float32_nchw_fma.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied.
+ */ + +#include "src/common/conv_bias.h" +#include "src/cuda/conv_bias/algo.h" +#include "src/cuda/convolution_helper/parameter.cuh" +#include "src/cuda/utils.h" + +using namespace megdnn; +using namespace cuda; + +bool ConvBiasForwardImpl::AlgoFloat32NCHWFMAImplicitBatchedGemm::is_available( + const SizeArgs& args) const { +#define RETURN_IF_FALSE(stmt_) \ + if (!(stmt_)) \ + return false; + RETURN_IF_FALSE( + args.src_layout->is_contiguous() && args.dst_layout->is_contiguous()); + using Param = param::ConvBias; + using Format = Param::Format; + using Sparse = Param::Sparse; + using Mode = Param::Mode; + auto&& param = args.opr->param(); + auto&& fm = args.filter_meta; + RETURN_IF_FALSE( + param.format == Format::NCHW && + args.src_layout->dtype.enumv() == DTypeEnum::Float32 && + args.filter_layout->dtype.enumv() == DTypeEnum::Float32 && + args.dst_layout->dtype.enumv() == DTypeEnum::Float32); + RETURN_IF_FALSE( + args.bias_layout->ndim <= 0 || + (args.bias_layout->dtype.enumv() == DTypeEnum::Float32 && + check_bias_share_in_channel(*args.bias_layout, param.format))); + RETURN_IF_FALSE( + args.z_layout->ndim <= 0 || + args.z_layout->dtype.enumv() == DTypeEnum::Float32); + RETURN_IF_FALSE(param.sparse == Sparse::GROUP); + RETURN_IF_FALSE(param.mode == Mode::CROSS_CORRELATION); + // check if channelwise convolution + RETURN_IF_FALSE(fm.icpg == 1 && fm.ocpg == 1); + const auto* op = get_cutlass_conv_op( + args, ConvOperator::kFprop, ConvType::kDepthwiseConvolution, false, false); + RETURN_IF_FALSE(op != nullptr); + return true; +#undef RETURN_IF_FALSE +} + +void ConvBiasForwardImpl::AlgoFloat32NCHWFMAImplicitBatchedGemm::exec( + const ExecArgs& args) const { + auto&& param = args.opr->param(); + auto&& fm = args.filter_meta; + size_t n = args.src_layout->operator[](0), hi = args.src_layout->operator[](2), + wi = args.src_layout->operator[](3); + size_t ho = args.dst_layout->operator[](2), wo = args.dst_layout->operator[](3); + size_t co = fm.group; + size_t ci = co; 
+ // check if channelwise convolution + megdnn_assert(fm.icpg == 1 && fm.ocpg == 1); + auto&& stream = cuda_stream(args.opr->handle()); + + float alpha = 1.f; + float beta = args.bias_layout->ndim > 0 ? 1.f : 0.f; + void* bias_ptr = args.bias_layout->ndim > 0 ? args.bias_tensor->raw_ptr() : nullptr; + float gamma = args.z_layout->ndim > 0 ? 1.f : 0.f; + void* z_ptr = args.z_layout->ndim > 0 ? args.z_tensor->raw_ptr() : nullptr; + + // dummy parameters, used for quantization cases + float theta = 0.f; + float delta = 0.f; + float threshold = 0.f; + + const auto* op = get_cutlass_conv_op( + args, ConvOperator::kFprop, ConvType::kDepthwiseConvolution, false, false); + + UNPACK_CONV_PARAMETER(fm, param); + MARK_USED_VAR + execute_cutlass_conv_op( + op, args.src_tensor->raw_ptr(), args.filter_tensor->raw_ptr(), bias_ptr, + z_ptr, args.dst_tensor->raw_ptr(), nullptr, n, hi, wi, ci, co, fh, fw, ho, + wo, ph, pw, sh, sw, dh, dw, &alpha, &beta, &gamma, &delta, &theta, + &threshold, nullptr, stream, nullptr, fm.group); + + after_kernel_launch(); +} + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/conv_bias/opr_impl.h b/dnn/src/cuda/conv_bias/opr_impl.h index 3bb2b6fb..a7184243 100644 --- a/dnn/src/cuda/conv_bias/opr_impl.h +++ b/dnn/src/cuda/conv_bias/opr_impl.h @@ -71,6 +71,9 @@ public: class AlgoInt4Int4NHWCIMMAImplicitGemm; class AlgoUInt4Int4NHWCIMMAImplicitGemm; class AlgoBFloat16; + // The following algorithms are suitable for channel wise convolution + class AlgoFloat32NCHWFMAImplicitBatchedGemm; + class AlgoFloat16NCHWHMMAImplicitBatchedGemm; class AlgoPack; diff --git a/dnn/src/cuda/convolution/backward_data/implicit_gemm_int8_nchw4_dp4a.cpp b/dnn/src/cuda/convolution/backward_data/implicit_gemm_int8_nchw4_dp4a.cpp index af063c85..1f01859c 100644 --- a/dnn/src/cuda/convolution/backward_data/implicit_gemm_int8_nchw4_dp4a.cpp +++ b/dnn/src/cuda/convolution/backward_data/implicit_gemm_int8_nchw4_dp4a.cpp @@ -39,6 +39,7 @@ const void* 
ConvolutionBackwardDataImpl::AlgoInt8NCHW4DotProdImplicitGemm:: LayoutTypeID::kTensorNC4HW4, NumericTypeID::kS32, LayoutTypeID::kTensorNC4HW4, + NumericTypeID::kS32, cutlass::conv::ConvType::kConvolution, m_algo_param.threadblock_m, m_algo_param.threadblock_n, @@ -52,6 +53,8 @@ const void* ConvolutionBackwardDataImpl::AlgoInt8NCHW4DotProdImplicitGemm:: cutlass::epilogue::EpilogueType::kBiasAddLinearCombinationClamp, m_algo_param.stage, special_optimization, + 4, + 16, false}; return (void*)Singleton::get().operation_table.find_op(key); } diff --git a/dnn/src/cuda/convolution/backward_data/implicit_gemm_int8_nchw_dp4a.cpp b/dnn/src/cuda/convolution/backward_data/implicit_gemm_int8_nchw_dp4a.cpp index 1349ae32..eb93f77a 100644 --- a/dnn/src/cuda/convolution/backward_data/implicit_gemm_int8_nchw_dp4a.cpp +++ b/dnn/src/cuda/convolution/backward_data/implicit_gemm_int8_nchw_dp4a.cpp @@ -39,6 +39,7 @@ const void* ConvolutionBackwardDataImpl::AlgoInt8NCHWDotProdImplicitGemm:: LayoutTypeID::kTensorNC4HW4, NumericTypeID::kS32, LayoutTypeID::kTensorNC4HW4, + NumericTypeID::kS32, cutlass::conv::ConvType::kConvolution, 16, 64, @@ -52,6 +53,8 @@ const void* ConvolutionBackwardDataImpl::AlgoInt8NCHWDotProdImplicitGemm:: cutlass::epilogue::EpilogueType::kBiasAddLinearCombinationClamp, 2, special_optimization, + 4, + 4, false}; return (void*)Singleton::get().operation_table.find_op(key); } diff --git a/dnn/src/cuda/convolution/backward_data/implicit_gemm_int8_nhwc_imma.cpp b/dnn/src/cuda/convolution/backward_data/implicit_gemm_int8_nhwc_imma.cpp index 04606fb2..7f71e9e8 100644 --- a/dnn/src/cuda/convolution/backward_data/implicit_gemm_int8_nhwc_imma.cpp +++ b/dnn/src/cuda/convolution/backward_data/implicit_gemm_int8_nhwc_imma.cpp @@ -50,6 +50,7 @@ const void* ConvolutionBackwardDataImpl::AlgoInt8NHWCIMMAImplicitGemm::get_avail LayoutTypeID::kTensorNHWC, NumericTypeID::kS32, LayoutTypeID::kTensorNHWC, + NumericTypeID::kS32, cutlass::conv::ConvType::kConvolution, 
m_algo_param.threadblock_m, m_algo_param.threadblock_n, @@ -63,6 +64,8 @@ const void* ConvolutionBackwardDataImpl::AlgoInt8NHWCIMMAImplicitGemm::get_avail cutlass::epilogue::EpilogueType::kBiasAddLinearCombinationClamp, m_algo_param.stage, special_optimization, + m_algo_param.access_size, + m_algo_param.access_size, false}; return (void*)Singleton::get().operation_table.find_op(key); } diff --git a/dnn/src/cuda/cutlass/initialize_all.cu b/dnn/src/cuda/cutlass/initialize_all.cu index 137f78fa..3a43f8de 100644 --- a/dnn/src/cuda/cutlass/initialize_all.cu +++ b/dnn/src/cuda/cutlass/initialize_all.cu @@ -54,24 +54,28 @@ namespace library { void initialize_all_gemm_simt_operations(Manifest& manifest); void initialize_all_conv2d_simt_operations(Manifest& manifest); void initialize_all_deconv_simt_operations(Manifest& manifest); +void initialize_all_dwconv2d_fprop_simt_operations(Manifest& manifest); #if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) && CUTLASS_ARCH_MMA_SM75_SUPPORTED void initialize_all_gemm_tensorop884_operations(Manifest& manifest); void initialize_all_gemm_tensorop1688_operations(Manifest& manifest); void initialize_all_conv2d_tensorop8816_operations(Manifest& manifest); void initialize_all_conv2d_tensorop8832_operations(Manifest& manifest); void initialize_all_deconv_tensorop8816_operations(Manifest& manifest); +void initialize_all_dwconv2d_fprop_tensorop884_operations(Manifest& manifest); #endif void initialize_all(Manifest& manifest) { initialize_all_gemm_simt_operations(manifest); initialize_all_conv2d_simt_operations(manifest); initialize_all_deconv_simt_operations(manifest); + initialize_all_dwconv2d_fprop_simt_operations(manifest); #if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) && CUTLASS_ARCH_MMA_SM75_SUPPORTED initialize_all_gemm_tensorop884_operations(manifest); initialize_all_gemm_tensorop1688_operations(manifest); initialize_all_conv2d_tensorop8816_operations(manifest); initialize_all_conv2d_tensorop8832_operations(manifest); 
initialize_all_deconv_tensorop8816_operations(manifest); + initialize_all_dwconv2d_fprop_tensorop884_operations(manifest); #endif } diff --git a/dnn/src/cuda/cutlass/library.h b/dnn/src/cuda/cutlass/library.h index dbc841e6..77c7b91c 100644 --- a/dnn/src/cuda/cutlass/library.h +++ b/dnn/src/cuda/cutlass/library.h @@ -223,6 +223,9 @@ enum class ThreadblockSwizzleID { kConvolutionFpropTrans, kConvolutionDgradNCxHWx, kConvolutionDgradTrans, + kDepthwiseConvolutionFprop, + kDepthwiseConvolutionDgrad, + kDepthwiseConvolutionWgrad, kInvalid }; diff --git a/dnn/src/cuda/cutlass/library_internal.h b/dnn/src/cuda/cutlass/library_internal.h index bd698fae..e1e3dade 100644 --- a/dnn/src/cuda/cutlass/library_internal.h +++ b/dnn/src/cuda/cutlass/library_internal.h @@ -570,6 +570,27 @@ struct ThreadblockSwizzleMap< ThreadblockSwizzleID::kConvolutionDgradTrans; }; +template <> +struct ThreadblockSwizzleMap< + conv::threadblock::DepthwiseConvolutionFpropThreadblockSwizzle> { + static ThreadblockSwizzleID const kId = + ThreadblockSwizzleID::kDepthwiseConvolutionFprop; +}; + +template <> +struct ThreadblockSwizzleMap< + conv::threadblock::DepthwiseConvolutionDgradThreadblockSwizzle> { + static ThreadblockSwizzleID const kId = + ThreadblockSwizzleID::kDepthwiseConvolutionDgrad; +}; + +template <> +struct ThreadblockSwizzleMap< + conv::threadblock::DepthwiseConvolutionWgradThreadblockSwizzle> { + static ThreadblockSwizzleID const kId = + ThreadblockSwizzleID::kDepthwiseConvolutionWgrad; +}; + ///////////////////////////////////////////////////////////////////////////////////////////////// template diff --git a/dnn/src/cuda/cutlass/operation_table.cpp b/dnn/src/cuda/cutlass/operation_table.cpp index 290d2aef..7762be04 100644 --- a/dnn/src/cuda/cutlass/operation_table.cpp +++ b/dnn/src/cuda/cutlass/operation_table.cpp @@ -99,6 +99,8 @@ ConvolutionKey get_convolution_key_from_desc(const ConvolutionDescription& desc) key.layout_dst = desc.dst.layout; key.element_bias = desc.bias.element; 
key.layout_bias = desc.bias.layout; + key.element_accumulator = + desc.tile_description.math_instruction.element_accumulator; key.convolution_type = desc.convolution_type; @@ -124,6 +126,8 @@ ConvolutionKey get_convolution_key_from_desc(const ConvolutionDescription& desc) key.stages = desc.tile_description.threadblock_stages; key.special_optimization = desc.special_optimization; + key.alignment_src = desc.src.alignment; + key.alignment_filter = desc.filter.alignment; key.without_shared_load = desc.without_shared_load; return key; diff --git a/dnn/src/cuda/cutlass/operation_table.h b/dnn/src/cuda/cutlass/operation_table.h index 3e38bf95..1d4f3bcb 100644 --- a/dnn/src/cuda/cutlass/operation_table.h +++ b/dnn/src/cuda/cutlass/operation_table.h @@ -188,6 +188,7 @@ struct ConvolutionKey { library::LayoutTypeID layout_dst; library::NumericTypeID element_bias; library::LayoutTypeID layout_bias; + NumericTypeID element_accumulator; conv::ConvType convolution_type; @@ -206,6 +207,10 @@ struct ConvolutionKey { epilogue::EpilogueType epilogue_type; int stages; conv::SpecialOptimizeDesc special_optimization; + + int alignment_src; + int alignment_filter; + bool without_shared_load; inline bool operator==(ConvolutionKey const& rhs) const { @@ -215,6 +220,7 @@ struct ConvolutionKey { (layout_filter == rhs.layout_filter) && (element_dst == rhs.element_dst) && (layout_dst == rhs.layout_dst) && (element_bias == rhs.element_bias) && (layout_bias == rhs.layout_bias) && + (element_accumulator == rhs.element_accumulator) && (convolution_type == rhs.convolution_type) && (threadblock_shape_m == rhs.threadblock_shape_m) && (threadblock_shape_n == rhs.threadblock_shape_n) && @@ -227,6 +233,8 @@ struct ConvolutionKey { (instruction_shape_k == rhs.instruction_shape_k) && (epilogue_type == rhs.epilogue_type) && (stages == rhs.stages) && (special_optimization == rhs.special_optimization) && + (alignment_src == rhs.alignment_src) && + (alignment_filter == rhs.alignment_filter) && 
(without_shared_load == rhs.without_shared_load); } @@ -254,6 +262,7 @@ struct ConvolutionKey { "\n layout_dst: " + to_string(layout_dst) + "\n element_bias: " + to_string(element_bias) + "\n layout_bias: " + to_string(layout_bias) + + "\n element_accumulator: " + to_string(element_accumulator) + "\n convolution_type: " + to_string(convolution_type) + "\n threadblock_shape: " + threadblock_shape_str + "\n warp_shape: " + warp_shape_str + @@ -261,6 +270,8 @@ struct ConvolutionKey { "\n epilogue_type: " + to_string(epilogue_type) + "\n stages: " + std::to_string(stages) + "\n special_optimization: " + to_string(special_optimization) + + "\n alignment_src: " + std::to_string(alignment_src) + + "\n alignment_filter: " + std::to_string(alignment_filter) + "\n without_shared_load: " + to_string(without_shared_load) + "\n}"; } }; @@ -278,6 +289,7 @@ struct ConvolutionKeyHasher { .update(&key.layout_dst, sizeof(key.layout_dst)) .update(&key.element_bias, sizeof(key.element_bias)) .update(&key.layout_bias, sizeof(key.layout_bias)) + .update(&key.element_accumulator, sizeof(key.element_accumulator)) .update(&key.convolution_type, sizeof(key.convolution_type)) .update(&key.threadblock_shape_m, sizeof(key.threadblock_shape_m)) .update(&key.threadblock_shape_n, sizeof(key.threadblock_shape_n)) @@ -291,6 +303,8 @@ struct ConvolutionKeyHasher { .update(&key.epilogue_type, sizeof(key.epilogue_type)) .update(&key.stages, sizeof(key.stages)) .update(&key.special_optimization, sizeof(key.special_optimization)) + .update(&key.alignment_src, sizeof(key.alignment_src)) + .update(&key.alignment_filter, sizeof(key.alignment_filter)) .update(&key.without_shared_load, sizeof(key.without_shared_load)) .digest(); } diff --git a/dnn/test/cuda/chanwise_convolution.cpp b/dnn/test/cuda/chanwise_convolution.cpp index 8159d4b6..8ad2160e 100644 --- a/dnn/test/cuda/chanwise_convolution.cpp +++ b/dnn/test/cuda/chanwise_convolution.cpp @@ -38,8 +38,10 @@ bool check_need_full_bench() { } #endif 
-Convolution::Param gconv_param(Convolution::Param p) { +Convolution::Param gconv_param(Convolution::Param p, bool io16xc32 = false) { p.sparse = Convolution::Param::Sparse::GROUP; + if (io16xc32) + p.compute_mode = Convolution::Param::ComputeMode::FLOAT32; return p; } @@ -421,6 +423,129 @@ TEST_F(CUDA, CHANWISE_CONVOLUTION_BACKWARD_FILTER) { } } +namespace { +template +struct AlgoCheckerMaker { + static auto make(const char* name, bool* require_algo) { + return AlgoChecker(name, require_algo); + } +}; + +template <> +struct AlgoCheckerMaker { + static auto make(const char* name, bool* require_algo) { + return AlgoChecker( + ExecutionPolicyAlgoName{ + "DEFAULT", + {{ConvBiasForward::algo_name( + name, {}) + .c_str(), + {}}}}, + require_algo); + } +}; + +template +void check_chanwise(DType io_type, DType comp_type, Handle* handle, const char* name) { + Checker checker(handle); + bool require_algo = false; + checker.set_before_exec_callback(AlgoCheckerMaker::make(name, &require_algo)); + checker.set_dtype(0, io_type).set_dtype(1, io_type).set_dtype(2, io_type); + bool io16xc32 = false; + if (io_type == dtype::Float16()) { + if (comp_type == dtype::Float16()) { + checker.set_epsilon(1e-1); + } else { + io16xc32 = true; + } + } + // dispatch testcase by operation + if (std::is_same::value) { + // align 8 + checker.set_param(gconv_param({M, 7, 7, 1, 1}, io16xc32)) + .execs({{8, 2, 16, 16}, {2, 1, 1, 15, 15}, {}}); + // align 1 + checker.set_param(gconv_param({M, 7, 7, 1, 1}, io16xc32)) + .execs({{8, 2, 15, 15}, {2, 1, 1, 15, 15}, {}}); + // align 2 + checker.set_param(gconv_param({M, 7, 7, 1, 1}, io16xc32)) + .execs({{8, 2, 14, 14}, {2, 1, 1, 15, 15}, {}}); + // custom padding + checker.set_param(gconv_param({M, 3, 3, 1, 1}, io16xc32)) + .execs({{8, 2, 16, 16}, {2, 1, 1, 15, 15}, {}}); + // custom stride + checker.set_param(gconv_param({M, 7, 7, 2, 2}, io16xc32)) + .execs({{8, 2, 16, 16}, {2, 1, 1, 15, 15}, {}}); + } else if (std::is_same::value) { + // align 8 + 
checker.set_param(gconv_param({M, 7, 7, 1, 1}, io16xc32)) + .execs({{2, 1, 1, 15, 15}, {8, 2, 16, 16}, {8, 2, 16, 16}}); + // align 1 + checker.set_param(gconv_param({M, 7, 7, 1, 1}, io16xc32)) + .execs({{2, 1, 1, 15, 15}, {8, 2, 15, 15}, {8, 2, 15, 15}}); + // align 2 + checker.set_param(gconv_param({M, 7, 7, 1, 1}, io16xc32)) + .execs({{2, 1, 1, 15, 15}, {8, 2, 14, 14}, {8, 2, 14, 14}}); + // custom padding + checker.set_param(gconv_param({M, 3, 3, 1, 1}, io16xc32)) + .execs({{2, 1, 1, 15, 15}, {8, 2, 8, 8}, {8, 2, 16, 16}}); + // custom stride + checker.set_param(gconv_param({M, 7, 7, 2, 2}, io16xc32)) + .execs({{2, 1, 1, 15, 15}, {8, 2, 7, 7}, {8, 2, 14, 14}}); + } else if (std::is_same::value) { + } +} +} // namespace + +#define MEGDNN_FOREACH_CUTLASS_CHANWISE_CONV_FWD_FMA_KERNEL(cb) \ + cb(1, 128, 128, 8, 32, 64, 8); \ + cb(2, 128, 64, 8, 64, 32, 8); \ + cb(3, 128, 32, 8, 64, 32, 8); \ + cb(4, 64, 128, 8, 64, 32, 8); \ + cb(5, 32, 128, 8, 32, 64, 8); \ + cb(6, 64, 64, 8, 64, 32, 8); \ + cb(7, 32, 64, 8, 32, 64, 8); \ + cb(8, 32, 32, 8, 32, 32, 8); \ + cb(9, 64, 32, 8, 64, 32, 8); + +#define cb(tag, tbm, tbn, tbk, wm, wn, wk) \ + TEST_F(CUDA, CHANWISE_CONVOLUTION_FORWARD_CUTLASS_FMA_##tag) { \ + check_chanwise( \ + dtype::Float32(), dtype::Float32(), handle_cuda(), \ + "FLOAT32_NCHW_FMA_IMPLICIT_BATCHED_GEMM_" #tbm "X" #tbn "X" #tbk \ + "_" #wm "X" #wn "X" #wk "_2stage"); \ + } + +MEGDNN_FOREACH_CUTLASS_CHANWISE_CONV_FWD_FMA_KERNEL(cb) + +#undef cb +#undef MEGDNN_FOREACH_CUTLASS_CHANWISE_CONV_FWD_FMA_KERNEL + +#define MEGDNN_FOREACH_CUTLASS_CHANWISE_CONV_FWD_HMMA_KERNEL(cb) \ + cb(1, 128, 128, 32, 32, 32, 32); \ + cb(2, 128, 256, 32, 64, 64, 32); \ + cb(3, 128, 64, 32, 32, 32, 32); \ + cb(4, 64, 128, 32, 32, 32, 32); \ + cb(5, 64, 64, 32, 32, 32, 32); + +// check both ioc16 and io16xc32 +#define cb(tag, tbm, tbn, tbk, wm, wn, wk) \ + TEST_F(CUDA, CHANWISE_CONVOLUTION_FORWARD_CUTLASS_HMMA_##tag) { \ + check_chanwise( \ + dtype::Float16(), dtype::Float16(), 
handle_cuda(), \ + "FLOAT16_NCHW_HMMA_IMPLICIT_BATCHED_GEMM_" #tbm "X" #tbn "X" #tbk \ + "_" #wm "X" #wn "X" #wk "_2stage"); \ + check_chanwise( \ + dtype::Float16(), dtype::Float32(), handle_cuda(), \ + "FLOAT16_NCHW_HMMA_IMPLICIT_BATCHED_GEMM_" #tbm "X" #tbn "X" #tbk \ + "_" #wm "X" #wn "X" #wk "_2stage"); \ + } + +MEGDNN_FOREACH_CUTLASS_CHANWISE_CONV_FWD_HMMA_KERNEL(cb) + +#undef cb +#undef MEGDNN_FOREACH_CUTLASS_CHANWISE_CONV_FWD_HMMA_KERNEL + #if MEGDNN_WITH_BENCHMARK TEST_F(CUDA, CHANWISE_CONVOLUTION_FORWARD_BENCH_CHECK) { auto handle = handle_cuda(); @@ -1123,6 +1248,82 @@ TEST_F(CUDA, BENCHMARK_CHANWISE_CONV_BWD_FILTER) { // clang-format on } +TEST_F(CUDA, BENCHMARK_CHANWISE_CONV_FORWARD_LARGE_KERNEL) { + CUBenchmarker bencher(handle_cuda()); + size_t RUNS = 100; + bencher.set_display(false).set_times(RUNS); + std::unique_ptr> proxy{ + new OprProxy{true}}; + bencher.set_proxy(proxy); + + Convolution::Param param; + param.format = ConvBias::Param::Format::NCHW; + param.sparse = Convolution::Param::Sparse::GROUP; + NormalRNG rng; + + auto run = [&](size_t batch, size_t c, size_t ih, size_t iw, size_t f, size_t s) { + param.pad_h = f / 2; + param.pad_w = f / 2; + param.stride_h = s; + param.stride_w = s; + param.compute_mode = param::Convolution::ComputeMode::DEFAULT; + + TensorShape src = {batch, c, ih, iw}, filter = {c, 1, 1, f, f}; + + TensorLayout dst_layout; + auto opr = handle_cuda()->create_operator(); + opr->param() = param; + opr->deduce_layout( + {src, dtype::Float32()}, {filter, dtype::Float32()}, dst_layout); + float bandwith = static_cast( + src.total_nr_elems() + filter.total_nr_elems() + + dst_layout.total_nr_elems()) / + (1024 * 1024 * 1024) * 1e3; + + bencher.set_param(param) + .set_dtype(0, dtype::Float32()) + .set_dtype(1, dtype::Float32()) + .set_dtype(2, dtype::Float32()) + .set_rng(0, &rng) + .set_rng(1, &rng); + bencher.proxy()->target_execution_policy = {}; + auto time_in_ms_fp32 = bencher.execs({src, filter, {}}) / RUNS; + + 
bencher.set_param(param) + .set_dtype(0, dtype::Float16()) + .set_dtype(1, dtype::Float16()) + .set_dtype(2, dtype::Float16()) + .set_rng(0, &rng) + .set_rng(1, &rng); + bencher.proxy()->target_execution_policy = {}; + auto time_in_ms_fp16 = bencher.execs({src, filter, {}}) / RUNS; + + bencher.proxy()->target_execution_policy.algo.reset(); + param.compute_mode = param::Convolution::ComputeMode::FLOAT32; + bencher.set_param(param); + auto time_in_ms_pseudo_fp16 = bencher.execs({src, filter, {}}) / RUNS; + + printf("stride=%zu src=%s, filter=%s, float32: %.2fms %.2fGB/s " + "float16: %.2fms %.2fGB/s " + "pseudo float16: %.2fms %.2fGB/s " + "speedup: " + "%0.2f (fp16/fp32) %.2f (fp16/pseudo fp16)\n", + s, src.to_string().c_str(), filter.to_string().c_str(), time_in_ms_fp32, + bandwith * 4 / time_in_ms_fp32, time_in_ms_fp16, + bandwith * 2 / time_in_ms_fp16, time_in_ms_pseudo_fp16, + bandwith * 2 / time_in_ms_pseudo_fp16, time_in_ms_fp32 / time_in_ms_fp16, + time_in_ms_pseudo_fp16 / time_in_ms_fp16); + }; + + // clang-format off + for (size_t b : {32, 64}) + for (size_t f : {3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31}) { + run(b, 384, 32, 32, f, 1); + run(b, 384, 64, 64, f, 1); + } + // clang-format on +} + #endif // vim: syntax=cpp.doxygen