GitOrigin-RevId: 0b5574f526
tags/v1.6.0-rc1
@@ -20,7 +20,7 @@ class Conv2dOperation: | |||||
# | # | ||||
def __init__(self, conv_kind, conv_type, arch, tile_description, src, flt, bias, dst, element_epilogue, \ | def __init__(self, conv_kind, conv_type, arch, tile_description, src, flt, bias, dst, element_epilogue, \ | ||||
epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity4, \ | epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity4, \ | ||||
need_load_from_const = True, implicit_gemm_mode = ImplicitGemmMode.GemmNt): | |||||
need_load_from_const = True, implicit_gemm_mode = ImplicitGemmMode.GemmNT, without_shared_load = False): | |||||
self.operation_kind = OperationKind.Conv2d | self.operation_kind = OperationKind.Conv2d | ||||
self.conv_kind = conv_kind | self.conv_kind = conv_kind | ||||
@@ -36,6 +36,7 @@ class Conv2dOperation: | |||||
self.swizzling_functor = swizzling_functor | self.swizzling_functor = swizzling_functor | ||||
self.need_load_from_const = need_load_from_const | self.need_load_from_const = need_load_from_const | ||||
self.implicit_gemm_mode = implicit_gemm_mode | self.implicit_gemm_mode = implicit_gemm_mode | ||||
self.without_shared_load = without_shared_load | |||||
# | # | ||||
def accumulator_type(self): | def accumulator_type(self): | ||||
accum = self.tile_description.math_instruction.element_accumulator | accum = self.tile_description.math_instruction.element_accumulator | ||||
@@ -58,11 +59,15 @@ class Conv2dOperation: | |||||
unity_kernel = '' | unity_kernel = '' | ||||
if not self.need_load_from_const: | if not self.need_load_from_const: | ||||
unity_kernel = '_1x1' | |||||
unity_kernel = '_1x1' | |||||
return "%s%s%s%s%s_%s" % (ShortDataTypeNames[self.accumulator_type()], \ | |||||
reorder_k = '' | |||||
if self.without_shared_load: | |||||
reorder_k = '_roc' | |||||
return "%s%s%s%s%s%s_%s" % (ShortDataTypeNames[self.accumulator_type()], \ | |||||
inst_shape, intermediate_type, ConvKindNames[self.conv_kind], unity_kernel, \ | inst_shape, intermediate_type, ConvKindNames[self.conv_kind], unity_kernel, \ | ||||
ShortEpilogueNames[self.epilogue_functor]) | |||||
reorder_k, ShortEpilogueNames[self.epilogue_functor]) | |||||
# | # | ||||
def extended_name(self): | def extended_name(self): | ||||
@@ -177,7 +182,8 @@ using Convolution = | |||||
${alignment_filter}, | ${alignment_filter}, | ||||
${nonuninity_kernel}, | ${nonuninity_kernel}, | ||||
${math_operator}, | ${math_operator}, | ||||
${implicit_gemm_mode}>; | |||||
${implicit_gemm_mode}, | |||||
${without_shared_load}>; | |||||
""" | """ | ||||
@@ -219,7 +225,8 @@ using Convolution = | |||||
'alignment_filter': str(operation.flt.alignment), | 'alignment_filter': str(operation.flt.alignment), | ||||
'nonuninity_kernel': str(operation.need_load_from_const).lower(), | 'nonuninity_kernel': str(operation.need_load_from_const).lower(), | ||||
'math_operator': MathOperationTag[operation.tile_description.math_instruction.math_operation], | 'math_operator': MathOperationTag[operation.tile_description.math_instruction.math_operation], | ||||
'implicit_gemm_mode': ImplicitGemmModeTag[operation.implicit_gemm_mode] | |||||
'implicit_gemm_mode': ImplicitGemmModeTag[operation.implicit_gemm_mode], | |||||
'without_shared_load': str(operation.without_shared_load).lower() | |||||
} | } | ||||
return SubstituteTemplate(self.template, values) | return SubstituteTemplate(self.template, values) | ||||
@@ -312,13 +319,13 @@ using Deconvolution = | |||||
# | # | ||||
def GenerateConv2d(conv_kind, tile_descriptions, src_layout, flt_layout, dst_layout, dst_type, min_cc, src_align = 32, flt_align = 32, dst_align = 128, \ | def GenerateConv2d(conv_kind, tile_descriptions, src_layout, flt_layout, dst_layout, dst_type, min_cc, src_align = 32, flt_align = 32, dst_align = 128, \ | ||||
skip_unity_kernel = False, implicit_gemm_mode = ImplicitGemmMode.GemmNt): | |||||
skip_unity_kernel = False, implicit_gemm_mode = ImplicitGemmMode.GemmNT, without_shared_load = False): | |||||
operations = [] | operations = [] | ||||
element_epilogue = DataType.f32 | element_epilogue = DataType.f32 | ||||
if conv_kind == ConvKind.Fprop: | if conv_kind == ConvKind.Fprop: | ||||
if src_layout == LayoutType.TensorNHWC: | |||||
swizzling_functor = SwizzlingFunctor.ConvFpropNHWC | |||||
if implicit_gemm_mode == ImplicitGemmMode.GemmTN: | |||||
swizzling_functor = SwizzlingFunctor.ConvFpropTrans | |||||
else: | else: | ||||
swizzling_functor = SwizzlingFunctor.ConvFpropNCxHWx | swizzling_functor = SwizzlingFunctor.ConvFpropNCxHWx | ||||
else: | else: | ||||
@@ -399,10 +406,10 @@ def GenerateConv2d(conv_kind, tile_descriptions, src_layout, flt_layout, dst_lay | |||||
bias = TensorDescription(bias_type, dst_layout, max(1, int(32 / DataTypeSize[bias_type]))) | bias = TensorDescription(bias_type, dst_layout, max(1, int(32 / DataTypeSize[bias_type]))) | ||||
dst = TensorDescription(dst_type, dst_layout, int(dst_align / DataTypeSize[dst_type])) | dst = TensorDescription(dst_type, dst_layout, int(dst_align / DataTypeSize[dst_type])) | ||||
new_operation = Conv2dOperation(conv_kind, ConvType.Convolution, min_cc, tile, src, flt, bias, dst, element_epilogue, epilogue, swizzling_functor, True, implicit_gemm_mode) | |||||
new_operation = Conv2dOperation(conv_kind, ConvType.Convolution, min_cc, tile, src, flt, bias, dst, element_epilogue, epilogue, swizzling_functor, True, implicit_gemm_mode, without_shared_load) | |||||
operations.append(new_operation) | operations.append(new_operation) | ||||
if not skip_unity_kernel: | if not skip_unity_kernel: | ||||
new_operation = Conv2dOperation(conv_kind, ConvType.Convolution, min_cc, tile, src, flt, bias, dst, element_epilogue, epilogue, swizzling_functor, False, implicit_gemm_mode) | |||||
new_operation = Conv2dOperation(conv_kind, ConvType.Convolution, min_cc, tile, src, flt, bias, dst, element_epilogue, epilogue, swizzling_functor, False, implicit_gemm_mode, without_shared_load) | |||||
operations.append(new_operation) | operations.append(new_operation) | ||||
return operations | return operations | ||||
@@ -175,12 +175,10 @@ def GenerateConv2d_Simt(args): | |||||
TileDescription([128, 128, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc), | TileDescription([128, 128, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc), | ||||
TileDescription([128, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), | TileDescription([128, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), | ||||
TileDescription([ 64, 128, 32], 2, [1, 4, 1], math_inst, min_cc, max_cc), | TileDescription([ 64, 128, 32], 2, [1, 4, 1], math_inst, min_cc, max_cc), | ||||
TileDescription([ 64, 64, 32], 2, [1, 2, 1], math_inst, min_cc, max_cc), | |||||
TileDescription([128, 32, 32], 2, [2, 1, 1], math_inst, min_cc, max_cc), | TileDescription([128, 32, 32], 2, [2, 1, 1], math_inst, min_cc, max_cc), | ||||
TileDescription([ 32, 128, 32], 2, [1, 2, 1], math_inst, min_cc, max_cc), | TileDescription([ 32, 128, 32], 2, [1, 2, 1], math_inst, min_cc, max_cc), | ||||
TileDescription([ 32, 64, 32], 2, [1, 1, 1], math_inst, min_cc, max_cc), | TileDescription([ 32, 64, 32], 2, [1, 1, 1], math_inst, min_cc, max_cc), | ||||
TileDescription([ 64, 32, 32], 2, [1, 1, 1], math_inst, min_cc, max_cc), | TileDescription([ 64, 32, 32], 2, [1, 1, 1], math_inst, min_cc, max_cc), | ||||
TileDescription([ 32, 32, 32], 2, [1, 1, 1], math_inst, min_cc, max_cc), | |||||
TileDescription([ 16, 128, 16], 1, [1, 1, 1], math_inst, min_cc, max_cc), | TileDescription([ 16, 128, 16], 1, [1, 1, 1], math_inst, min_cc, max_cc), | ||||
TileDescription([ 16, 64, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc), | TileDescription([ 16, 64, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc), | ||||
] | ] | ||||
@@ -223,28 +221,36 @@ def GenerateConv2d_TensorOp_8816(args): | |||||
for dst_type, dst_layout in zip(dst_types, dst_layouts): | for dst_type, dst_layout in zip(dst_types, dst_layouts): | ||||
if dst_layout == LayoutType.TensorNC32HW32: | if dst_layout == LayoutType.TensorNC32HW32: | ||||
tile_descriptions = [ | tile_descriptions = [ | ||||
TileDescription([256, 128, 64], 2, [4, 2, 1], math_inst, min_cc, max_cc), | |||||
TileDescription([128, 256, 64], 2, [2, 4, 1], math_inst, min_cc, max_cc), | TileDescription([128, 256, 64], 2, [2, 4, 1], math_inst, min_cc, max_cc), | ||||
TileDescription([256, 128, 64], 2, [4, 2, 1], math_inst, min_cc, max_cc), | |||||
TileDescription([128, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), | TileDescription([128, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), | ||||
TileDescription([ 64, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), | |||||
TileDescription([128, 64, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), | TileDescription([128, 64, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), | ||||
TileDescription([ 64, 64, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), | |||||
TileDescription([ 32, 64, 64], 2, [1, 4, 1], math_inst, min_cc, max_cc), | |||||
TileDescription([ 64, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), | |||||
TileDescription([128, 64, 32], 1, [2, 2, 1], math_inst, min_cc, max_cc), | |||||
TileDescription([128, 32, 32], 1, [2, 1, 1], math_inst, min_cc, max_cc), | |||||
TileDescription([ 64, 128, 32], 1, [2, 2, 1], math_inst, min_cc, max_cc), | |||||
TileDescription([ 32, 128, 32], 1, [1, 2, 1], math_inst, min_cc, max_cc), | |||||
] | ] | ||||
operations += GenerateConv2d(ConvKind.Fprop, tile_descriptions, layout[0], layout[1], | |||||
dst_layout, dst_type, min_cc, 128, 128, 64, | |||||
False, ImplicitGemmMode.GemmTN, True) | |||||
else: | else: | ||||
assert dst_layout == LayoutType.TensorNC4HW4 | assert dst_layout == LayoutType.TensorNC4HW4 | ||||
tile_descriptions = [ | tile_descriptions = [ | ||||
TileDescription([256, 128, 64], 2, [4, 2, 1], math_inst, min_cc, max_cc), | |||||
TileDescription([128, 256, 64], 2, [2, 4, 1], math_inst, min_cc, max_cc), | TileDescription([128, 256, 64], 2, [2, 4, 1], math_inst, min_cc, max_cc), | ||||
TileDescription([256, 128, 64], 2, [4, 2, 1], math_inst, min_cc, max_cc), | |||||
TileDescription([128, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), | TileDescription([128, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), | ||||
TileDescription([ 64, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), | |||||
TileDescription([128, 64, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), | TileDescription([128, 64, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), | ||||
TileDescription([ 64, 64, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), | |||||
TileDescription([ 32, 64, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), | |||||
TileDescription([ 64, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), | |||||
TileDescription([128, 64, 32], 1, [2, 2, 1], math_inst, min_cc, max_cc), | |||||
TileDescription([128, 32, 32], 1, [2, 1, 1], math_inst, min_cc, max_cc), | |||||
TileDescription([ 64, 128, 32], 1, [2, 2, 1], math_inst, min_cc, max_cc), | |||||
TileDescription([ 32, 128, 32], 1, [1, 2, 1], math_inst, min_cc, max_cc), | |||||
] | ] | ||||
operations += GenerateConv2d(ConvKind.Fprop, tile_descriptions, layout[0], layout[1], | |||||
operations += GenerateConv2d(ConvKind.Fprop, tile_descriptions, layout[0], layout[1], | |||||
dst_layout, dst_type, min_cc, 128, 128, 64, | dst_layout, dst_type, min_cc, 128, 128, 64, | ||||
False) | False) | ||||
return operations | return operations | ||||
def GenerateConv2d_TensorOp_8832(args): | def GenerateConv2d_TensorOp_8832(args): | ||||
@@ -279,12 +285,14 @@ def GenerateConv2d_TensorOp_8832(args): | |||||
for dst_layout in dst_layouts: | for dst_layout in dst_layouts: | ||||
dst_type = math_inst.element_b | dst_type = math_inst.element_b | ||||
tile_descriptions = [ | tile_descriptions = [ | ||||
TileDescription([256, 128, 128], 2, [4, 2, 1], math_inst, min_cc, max_cc), | |||||
TileDescription([128, 256, 128], 2, [2, 4, 1], math_inst, min_cc, max_cc), | |||||
TileDescription([128, 128, 128], 2, [2, 2, 1], math_inst, min_cc, max_cc), | TileDescription([128, 128, 128], 2, [2, 2, 1], math_inst, min_cc, max_cc), | ||||
TileDescription([128, 64, 128], 2, [2, 1, 1], math_inst, min_cc, max_cc), | |||||
TileDescription([128, 64, 64], 1, [2, 1, 1], math_inst, min_cc, max_cc), | |||||
] | ] | ||||
operations += GenerateConv2d(ConvKind.Fprop, tile_descriptions, layout[0], layout[1], | operations += GenerateConv2d(ConvKind.Fprop, tile_descriptions, layout[0], layout[1], | ||||
dst_layout, dst_type, min_cc, 128, 128, 64, | dst_layout, dst_type, min_cc, 128, 128, 64, | ||||
True) | |||||
True, ImplicitGemmMode.GemmTN, True) | |||||
layouts_nhwc = [ | layouts_nhwc = [ | ||||
(LayoutType.TensorNHWC, LayoutType.TensorNC8HW8, 32), | (LayoutType.TensorNHWC, LayoutType.TensorNC8HW8, 32), | ||||
@@ -299,14 +307,21 @@ def GenerateConv2d_TensorOp_8832(args): | |||||
for math_inst in math_instructions: | for math_inst in math_instructions: | ||||
for layout in layouts_nhwc: | for layout in layouts_nhwc: | ||||
for dst_layout in dst_layouts_nhwc: | for dst_layout in dst_layouts_nhwc: | ||||
dst_type = math_inst.element_b | |||||
tile_descriptions = [ | |||||
TileDescription([128, 32, 64], 2, [2, 1, 1], math_inst, min_cc, max_cc), | |||||
TileDescription([128, 64, 64], 2, [2, 1, 1], math_inst, min_cc, max_cc), | |||||
] | |||||
operations += GenerateConv2d(ConvKind.Fprop, tile_descriptions, layout[0], layout[1], | |||||
dst_layout, dst_type, min_cc, layout[2], layout[2], 32, | |||||
False, ImplicitGemmMode.GemmTn) | |||||
dst_type = math_inst.element_b | |||||
tile_descriptions = [ | |||||
TileDescription([128, 32, 64], 1, [2, 1, 1], math_inst, min_cc, max_cc), | |||||
TileDescription([128, 64, 64], 1, [2, 1, 1], math_inst, min_cc, max_cc), | |||||
] | |||||
for tile in tile_descriptions: | |||||
operations += GenerateConv2d(ConvKind.Fprop, [tile], layout[0], layout[1], | |||||
dst_layout, dst_type, min_cc, layout[2], layout[2], 32, | |||||
False, ImplicitGemmMode.GemmTN, False) | |||||
if tile.threadblock_shape[1] == 32 or tile.threadblock_shape[1] == 64: | |||||
dst_align = 32 if tile.threadblock_shape[1] == 32 else 64 | |||||
operations += GenerateConv2d(ConvKind.Fprop, [tile], layout[0], layout[1], | |||||
dst_layout, dst_type, min_cc, layout[2], layout[2], dst_align, | |||||
False, ImplicitGemmMode.GemmTN, True) | |||||
return operations | return operations | ||||
def GenerateDeconv_Simt(args): | def GenerateDeconv_Simt(args): | ||||
@@ -649,3 +664,4 @@ if __name__ == "__main__": | |||||
# | # | ||||
################################################################################################### | ################################################################################################### | ||||
@@ -464,10 +464,10 @@ EpilogueFunctorTag = { | |||||
ShortEpilogueNames = { | ShortEpilogueNames = { | ||||
EpilogueFunctor.BiasAddLinearCombinationHSwishClamp: 'hswish', | EpilogueFunctor.BiasAddLinearCombinationHSwishClamp: 'hswish', | ||||
EpilogueFunctor.BiasAddLinearCombinationReluClamp: 'relu', | EpilogueFunctor.BiasAddLinearCombinationReluClamp: 'relu', | ||||
EpilogueFunctor.BiasAddLinearCombinationClamp: 'identity', | |||||
EpilogueFunctor.BiasAddLinearCombinationClamp: 'id', | |||||
EpilogueFunctor.BiasAddLinearCombinationHSwish: 'hswish', | EpilogueFunctor.BiasAddLinearCombinationHSwish: 'hswish', | ||||
EpilogueFunctor.BiasAddLinearCombinationRelu: 'relu', | EpilogueFunctor.BiasAddLinearCombinationRelu: 'relu', | ||||
EpilogueFunctor.BiasAddLinearCombination: 'identity', | |||||
EpilogueFunctor.BiasAddLinearCombination: 'id', | |||||
} | } | ||||
@@ -482,7 +482,7 @@ class SwizzlingFunctor(enum.Enum): | |||||
Identity4 = enum_auto() | Identity4 = enum_auto() | ||||
Identity8 = enum_auto() | Identity8 = enum_auto() | ||||
ConvFpropNCxHWx = enum_auto() | ConvFpropNCxHWx = enum_auto() | ||||
ConvFpropNHWC = enum_auto() | |||||
ConvFpropTrans = enum_auto() | |||||
ConvDgradNCxHWx = enum_auto() | ConvDgradNCxHWx = enum_auto() | ||||
# | # | ||||
@@ -492,7 +492,7 @@ SwizzlingFunctorTag = { | |||||
SwizzlingFunctor.Identity4: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>', | SwizzlingFunctor.Identity4: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>', | ||||
SwizzlingFunctor.Identity8: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>', | SwizzlingFunctor.Identity8: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>', | ||||
SwizzlingFunctor.ConvFpropNCxHWx: 'cutlass::conv::threadblock::ConvolutionFpropNCxHWxThreadblockSwizzle', | SwizzlingFunctor.ConvFpropNCxHWx: 'cutlass::conv::threadblock::ConvolutionFpropNCxHWxThreadblockSwizzle', | ||||
SwizzlingFunctor.ConvFpropNHWC: 'cutlass::conv::threadblock::ConvolutionFpropNHWCThreadblockSwizzle', | |||||
SwizzlingFunctor.ConvFpropTrans: 'cutlass::conv::threadblock::ConvolutionFpropTransThreadblockSwizzle', | |||||
SwizzlingFunctor.ConvDgradNCxHWx: 'cutlass::conv::threadblock::ConvolutionDgradNCxHWxThreadblockSwizzle', | SwizzlingFunctor.ConvDgradNCxHWx: 'cutlass::conv::threadblock::ConvolutionDgradNCxHWxThreadblockSwizzle', | ||||
} | } | ||||
@@ -563,17 +563,17 @@ StrideSupportNames = { | |||||
} | } | ||||
class ImplicitGemmMode(enum.Enum): | class ImplicitGemmMode(enum.Enum): | ||||
GemmNt = enum_auto() | |||||
GemmTn = enum_auto() | |||||
GemmNT = enum_auto() | |||||
GemmTN = enum_auto() | |||||
ImplicitGemmModeNames = { | ImplicitGemmModeNames = { | ||||
ImplicitGemmMode.GemmNt: 'gemm_nt', | |||||
ImplicitGemmMode.GemmTn: 'gemm_tn', | |||||
ImplicitGemmMode.GemmNT: 'gemm_nt', | |||||
ImplicitGemmMode.GemmTN: 'gemm_tn', | |||||
} | } | ||||
ImplicitGemmModeTag = { | ImplicitGemmModeTag = { | ||||
ImplicitGemmMode.GemmNt: 'cutlass::conv::ImplicitGemmMode::GEMM_NT', | |||||
ImplicitGemmMode.GemmTn: 'cutlass::conv::ImplicitGemmMode::GEMM_TN', | |||||
ImplicitGemmMode.GemmNT: 'cutlass::conv::ImplicitGemmMode::GEMM_NT', | |||||
ImplicitGemmMode.GemmTN: 'cutlass::conv::ImplicitGemmMode::GEMM_TN', | |||||
} | } | ||||
################################################################################################### | ################################################################################################### | ||||
@@ -164,415 +164,461 @@ cutlass_gen_list = [ | |||||
"cutlass_simt_sgemv_batched_strided_1x32_32_tt_align1x4.cu", | "cutlass_simt_sgemv_batched_strided_1x32_32_tt_align1x4.cu", | ||||
"cutlass_simt_sgemv_batched_strided_1x32_16_tt_align1x2.cu", | "cutlass_simt_sgemv_batched_strided_1x32_16_tt_align1x2.cu", | ||||
"cutlass_simt_sgemv_batched_strided_1x32_8_tt_align1x1.cu", | "cutlass_simt_sgemv_batched_strided_1x32_8_tt_align1x1.cu", | ||||
"cutlass_simt_s8_idgrad_identity_s8_64x128x32_64x32x32_2_nc4hw4_k4rsc4.cu", | |||||
"cutlass_simt_s8_idgrad_identity_s8_32x128x32_32x64x32_2_nc4hw4_k4rsc4.cu", | |||||
"cutlass_simt_s8_idgrad_identity_s8_16x128x16_16x64x16_2_nc4hw4_k4rsc4.cu", | |||||
"cutlass_simt_s8_idgrad_identity_s8_16x128x16_16x128x16_1_nc4hw4_k4rsc4.cu", | |||||
"cutlass_simt_s8_idgrad_identity_s8_16x64x8_16x64x8_2_nc4hw4_k4rsc4.cu", | |||||
"cutlass_simt_s8_ifprop_identity_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||||
"cutlass_simt_s8_ifprop_1x1_identity_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||||
"cutlass_simt_s8_idgrad_id_s8_64x128x32_64x32x32_2_nc4hw4_k4rsc4.cu", | |||||
"cutlass_simt_s8_idgrad_id_s8_32x128x32_32x64x32_2_nc4hw4_k4rsc4.cu", | |||||
"cutlass_simt_s8_idgrad_id_s8_16x128x16_16x64x16_2_nc4hw4_k4rsc4.cu", | |||||
"cutlass_simt_s8_idgrad_id_s8_16x128x16_16x128x16_1_nc4hw4_k4rsc4.cu", | |||||
"cutlass_simt_s8_idgrad_id_s8_16x64x8_16x64x8_2_nc4hw4_k4rsc4.cu", | |||||
"cutlass_simt_s8_ifprop_id_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||||
"cutlass_simt_s8_ifprop_1x1_id_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||||
"cutlass_simt_s8_ifprop_relu_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4.cu", | "cutlass_simt_s8_ifprop_relu_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4.cu", | ||||
"cutlass_simt_s8_ifprop_1x1_relu_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4.cu", | "cutlass_simt_s8_ifprop_1x1_relu_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4.cu", | ||||
"cutlass_simt_s8_ifprop_hswish_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4.cu", | "cutlass_simt_s8_ifprop_hswish_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4.cu", | ||||
"cutlass_simt_s8_ifprop_1x1_hswish_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4.cu", | "cutlass_simt_s8_ifprop_1x1_hswish_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4.cu", | ||||
"cutlass_simt_s8_ifprop_identity_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||||
"cutlass_simt_s8_ifprop_1x1_identity_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||||
"cutlass_simt_s8_ifprop_id_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||||
"cutlass_simt_s8_ifprop_1x1_id_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||||
"cutlass_simt_s8_ifprop_relu_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4.cu", | "cutlass_simt_s8_ifprop_relu_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4.cu", | ||||
"cutlass_simt_s8_ifprop_1x1_relu_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4.cu", | "cutlass_simt_s8_ifprop_1x1_relu_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4.cu", | ||||
"cutlass_simt_s8_ifprop_hswish_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4.cu", | "cutlass_simt_s8_ifprop_hswish_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4.cu", | ||||
"cutlass_simt_s8_ifprop_1x1_hswish_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4.cu", | "cutlass_simt_s8_ifprop_1x1_hswish_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4.cu", | ||||
"cutlass_simt_s8_ifprop_identity_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||||
"cutlass_simt_s8_ifprop_1x1_identity_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||||
"cutlass_simt_s8_ifprop_id_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||||
"cutlass_simt_s8_ifprop_1x1_id_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||||
"cutlass_simt_s8_ifprop_relu_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4.cu", | "cutlass_simt_s8_ifprop_relu_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4.cu", | ||||
"cutlass_simt_s8_ifprop_1x1_relu_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4.cu", | "cutlass_simt_s8_ifprop_1x1_relu_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4.cu", | ||||
"cutlass_simt_s8_ifprop_hswish_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4.cu", | "cutlass_simt_s8_ifprop_hswish_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4.cu", | ||||
"cutlass_simt_s8_ifprop_1x1_hswish_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4.cu", | "cutlass_simt_s8_ifprop_1x1_hswish_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4.cu", | ||||
"cutlass_simt_s8_ifprop_identity_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||||
"cutlass_simt_s8_ifprop_1x1_identity_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||||
"cutlass_simt_s8_ifprop_relu_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||||
"cutlass_simt_s8_ifprop_1x1_relu_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||||
"cutlass_simt_s8_ifprop_hswish_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||||
"cutlass_simt_s8_ifprop_1x1_hswish_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||||
"cutlass_simt_s8_ifprop_identity_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||||
"cutlass_simt_s8_ifprop_1x1_identity_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||||
"cutlass_simt_s8_ifprop_id_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||||
"cutlass_simt_s8_ifprop_1x1_id_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||||
"cutlass_simt_s8_ifprop_relu_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4.cu", | "cutlass_simt_s8_ifprop_relu_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4.cu", | ||||
"cutlass_simt_s8_ifprop_1x1_relu_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4.cu", | "cutlass_simt_s8_ifprop_1x1_relu_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4.cu", | ||||
"cutlass_simt_s8_ifprop_hswish_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4.cu", | "cutlass_simt_s8_ifprop_hswish_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4.cu", | ||||
"cutlass_simt_s8_ifprop_1x1_hswish_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4.cu", | "cutlass_simt_s8_ifprop_1x1_hswish_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4.cu", | ||||
"cutlass_simt_s8_ifprop_identity_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4.cu", | |||||
"cutlass_simt_s8_ifprop_1x1_identity_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4.cu", | |||||
"cutlass_simt_s8_ifprop_id_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4.cu", | |||||
"cutlass_simt_s8_ifprop_1x1_id_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4.cu", | |||||
"cutlass_simt_s8_ifprop_relu_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4.cu", | "cutlass_simt_s8_ifprop_relu_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4.cu", | ||||
"cutlass_simt_s8_ifprop_1x1_relu_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4.cu", | "cutlass_simt_s8_ifprop_1x1_relu_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4.cu", | ||||
"cutlass_simt_s8_ifprop_hswish_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4.cu", | "cutlass_simt_s8_ifprop_hswish_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4.cu", | ||||
"cutlass_simt_s8_ifprop_1x1_hswish_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4.cu", | "cutlass_simt_s8_ifprop_1x1_hswish_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4.cu", | ||||
"cutlass_simt_s8_ifprop_identity_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4.cu", | |||||
"cutlass_simt_s8_ifprop_1x1_identity_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4.cu", | |||||
"cutlass_simt_s8_ifprop_id_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4.cu", | |||||
"cutlass_simt_s8_ifprop_1x1_id_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4.cu", | |||||
"cutlass_simt_s8_ifprop_relu_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4.cu", | "cutlass_simt_s8_ifprop_relu_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4.cu", | ||||
"cutlass_simt_s8_ifprop_1x1_relu_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4.cu", | "cutlass_simt_s8_ifprop_1x1_relu_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4.cu", | ||||
"cutlass_simt_s8_ifprop_hswish_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4.cu", | "cutlass_simt_s8_ifprop_hswish_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4.cu", | ||||
"cutlass_simt_s8_ifprop_1x1_hswish_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4.cu", | "cutlass_simt_s8_ifprop_1x1_hswish_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4.cu", | ||||
"cutlass_simt_s8_ifprop_identity_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||||
"cutlass_simt_s8_ifprop_1x1_identity_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||||
"cutlass_simt_s8_ifprop_id_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||||
"cutlass_simt_s8_ifprop_1x1_id_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||||
"cutlass_simt_s8_ifprop_relu_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4.cu", | "cutlass_simt_s8_ifprop_relu_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4.cu", | ||||
"cutlass_simt_s8_ifprop_1x1_relu_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4.cu", | "cutlass_simt_s8_ifprop_1x1_relu_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4.cu", | ||||
"cutlass_simt_s8_ifprop_hswish_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4.cu", | "cutlass_simt_s8_ifprop_hswish_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4.cu", | ||||
"cutlass_simt_s8_ifprop_1x1_hswish_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4.cu", | "cutlass_simt_s8_ifprop_1x1_hswish_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4.cu", | ||||
"cutlass_simt_s8_ifprop_identity_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4.cu", | |||||
"cutlass_simt_s8_ifprop_1x1_identity_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4.cu", | |||||
"cutlass_simt_s8_ifprop_relu_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4.cu", | |||||
"cutlass_simt_s8_ifprop_1x1_relu_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4.cu", | |||||
"cutlass_simt_s8_ifprop_hswish_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4.cu", | |||||
"cutlass_simt_s8_ifprop_1x1_hswish_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4.cu", | |||||
"cutlass_simt_s8_ifprop_identity_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4.cu", | |||||
"cutlass_simt_s8_ifprop_1x1_identity_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4.cu", | |||||
"cutlass_simt_s8_ifprop_id_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4.cu", | |||||
"cutlass_simt_s8_ifprop_1x1_id_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4.cu", | |||||
"cutlass_simt_s8_ifprop_relu_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4.cu", | "cutlass_simt_s8_ifprop_relu_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4.cu", | ||||
"cutlass_simt_s8_ifprop_1x1_relu_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4.cu", | "cutlass_simt_s8_ifprop_1x1_relu_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4.cu", | ||||
"cutlass_simt_s8_ifprop_hswish_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4.cu", | "cutlass_simt_s8_ifprop_hswish_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4.cu", | ||||
"cutlass_simt_s8_ifprop_1x1_hswish_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4.cu", | "cutlass_simt_s8_ifprop_1x1_hswish_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4.cu", | ||||
"cutlass_simt_s8_ifprop_identity_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4.cu", | |||||
"cutlass_simt_s8_ifprop_1x1_identity_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4.cu", | |||||
"cutlass_simt_s8_ifprop_id_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4.cu", | |||||
"cutlass_simt_s8_ifprop_1x1_id_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4.cu", | |||||
"cutlass_simt_s8_ifprop_relu_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4.cu", | "cutlass_simt_s8_ifprop_relu_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4.cu", | ||||
"cutlass_simt_s8_ifprop_1x1_relu_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4.cu", | "cutlass_simt_s8_ifprop_1x1_relu_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4.cu", | ||||
"cutlass_simt_s8_ifprop_hswish_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4.cu", | "cutlass_simt_s8_ifprop_hswish_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4.cu", | ||||
"cutlass_simt_s8_ifprop_1x1_hswish_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4.cu", | "cutlass_simt_s8_ifprop_1x1_hswish_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4.cu", | ||||
"cutlass_simt_s8_ifprop_identity_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||||
"cutlass_simt_s8_ifprop_1x1_identity_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||||
"cutlass_simt_s8_ifprop_id_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||||
"cutlass_simt_s8_ifprop_1x1_id_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||||
"cutlass_simt_s8_ifprop_relu_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | "cutlass_simt_s8_ifprop_relu_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | ||||
"cutlass_simt_s8_ifprop_1x1_relu_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | "cutlass_simt_s8_ifprop_1x1_relu_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | ||||
"cutlass_simt_s8_ifprop_hswish_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | "cutlass_simt_s8_ifprop_hswish_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | ||||
"cutlass_simt_s8_ifprop_1x1_hswish_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | "cutlass_simt_s8_ifprop_1x1_hswish_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | ||||
"cutlass_simt_s8_ifprop_identity_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||||
"cutlass_simt_s8_ifprop_1x1_identity_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||||
"cutlass_simt_s8_ifprop_id_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||||
"cutlass_simt_s8_ifprop_1x1_id_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||||
"cutlass_simt_s8_ifprop_relu_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | "cutlass_simt_s8_ifprop_relu_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | ||||
"cutlass_simt_s8_ifprop_1x1_relu_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | "cutlass_simt_s8_ifprop_1x1_relu_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | ||||
"cutlass_simt_s8_ifprop_hswish_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | "cutlass_simt_s8_ifprop_hswish_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | ||||
"cutlass_simt_s8_ifprop_1x1_hswish_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | "cutlass_simt_s8_ifprop_1x1_hswish_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | ||||
"cutlass_simt_s8_ifprop_identity_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||||
"cutlass_simt_s8_ifprop_1x1_identity_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||||
"cutlass_simt_s8_ifprop_id_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||||
"cutlass_simt_s8_ifprop_1x1_id_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||||
"cutlass_simt_s8_ifprop_relu_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | "cutlass_simt_s8_ifprop_relu_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | ||||
"cutlass_simt_s8_ifprop_1x1_relu_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | "cutlass_simt_s8_ifprop_1x1_relu_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | ||||
"cutlass_simt_s8_ifprop_hswish_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | "cutlass_simt_s8_ifprop_hswish_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | ||||
"cutlass_simt_s8_ifprop_1x1_hswish_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | "cutlass_simt_s8_ifprop_1x1_hswish_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | ||||
"cutlass_simt_s8_ifprop_identity_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||||
"cutlass_simt_s8_ifprop_1x1_identity_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||||
"cutlass_simt_s8_ifprop_relu_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||||
"cutlass_simt_s8_ifprop_1x1_relu_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||||
"cutlass_simt_s8_ifprop_hswish_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||||
"cutlass_simt_s8_ifprop_1x1_hswish_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||||
"cutlass_simt_s8_ifprop_identity_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||||
"cutlass_simt_s8_ifprop_1x1_identity_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||||
"cutlass_simt_s8_ifprop_id_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||||
"cutlass_simt_s8_ifprop_1x1_id_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||||
"cutlass_simt_s8_ifprop_relu_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | "cutlass_simt_s8_ifprop_relu_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | ||||
"cutlass_simt_s8_ifprop_1x1_relu_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | "cutlass_simt_s8_ifprop_1x1_relu_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | ||||
"cutlass_simt_s8_ifprop_hswish_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | "cutlass_simt_s8_ifprop_hswish_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | ||||
"cutlass_simt_s8_ifprop_1x1_hswish_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | "cutlass_simt_s8_ifprop_1x1_hswish_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | ||||
"cutlass_simt_s8_ifprop_identity_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||||
"cutlass_simt_s8_ifprop_1x1_identity_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||||
"cutlass_simt_s8_ifprop_id_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||||
"cutlass_simt_s8_ifprop_1x1_id_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||||
"cutlass_simt_s8_ifprop_relu_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | "cutlass_simt_s8_ifprop_relu_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | ||||
"cutlass_simt_s8_ifprop_1x1_relu_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | "cutlass_simt_s8_ifprop_1x1_relu_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | ||||
"cutlass_simt_s8_ifprop_hswish_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | "cutlass_simt_s8_ifprop_hswish_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | ||||
"cutlass_simt_s8_ifprop_1x1_hswish_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | "cutlass_simt_s8_ifprop_1x1_hswish_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | ||||
"cutlass_simt_s8_ifprop_identity_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||||
"cutlass_simt_s8_ifprop_1x1_identity_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||||
"cutlass_simt_s8_ifprop_id_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||||
"cutlass_simt_s8_ifprop_1x1_id_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||||
"cutlass_simt_s8_ifprop_relu_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | "cutlass_simt_s8_ifprop_relu_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | ||||
"cutlass_simt_s8_ifprop_1x1_relu_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | "cutlass_simt_s8_ifprop_1x1_relu_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | ||||
"cutlass_simt_s8_ifprop_hswish_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | "cutlass_simt_s8_ifprop_hswish_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | ||||
"cutlass_simt_s8_ifprop_1x1_hswish_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | "cutlass_simt_s8_ifprop_1x1_hswish_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | ||||
"cutlass_simt_s8_ifprop_identity_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||||
"cutlass_simt_s8_ifprop_1x1_identity_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||||
"cutlass_simt_s8_ifprop_id_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||||
"cutlass_simt_s8_ifprop_1x1_id_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||||
"cutlass_simt_s8_ifprop_relu_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | "cutlass_simt_s8_ifprop_relu_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | ||||
"cutlass_simt_s8_ifprop_1x1_relu_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | "cutlass_simt_s8_ifprop_1x1_relu_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | ||||
"cutlass_simt_s8_ifprop_hswish_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | "cutlass_simt_s8_ifprop_hswish_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | ||||
"cutlass_simt_s8_ifprop_1x1_hswish_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | "cutlass_simt_s8_ifprop_1x1_hswish_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | ||||
"cutlass_simt_s8_ifprop_identity_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||||
"cutlass_simt_s8_ifprop_1x1_identity_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||||
"cutlass_simt_s8_ifprop_relu_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||||
"cutlass_simt_s8_ifprop_1x1_relu_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||||
"cutlass_simt_s8_ifprop_hswish_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||||
"cutlass_simt_s8_ifprop_1x1_hswish_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||||
"cutlass_simt_u4_ifprop_identity_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||||
"cutlass_simt_u4_ifprop_id_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||||
"cutlass_simt_u4_ifprop_relu_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | "cutlass_simt_u4_ifprop_relu_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | ||||
"cutlass_simt_u4_ifprop_hswish_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | "cutlass_simt_u4_ifprop_hswish_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | ||||
"cutlass_simt_u4_ifprop_identity_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||||
"cutlass_simt_u4_ifprop_id_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||||
"cutlass_simt_u4_ifprop_relu_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | "cutlass_simt_u4_ifprop_relu_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | ||||
"cutlass_simt_u4_ifprop_hswish_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | "cutlass_simt_u4_ifprop_hswish_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | ||||
"cutlass_simt_u4_ifprop_identity_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||||
"cutlass_simt_u4_ifprop_id_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||||
"cutlass_simt_u4_ifprop_relu_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | "cutlass_simt_u4_ifprop_relu_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | ||||
"cutlass_simt_u4_ifprop_hswish_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | "cutlass_simt_u4_ifprop_hswish_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | ||||
"cutlass_simt_u4_ifprop_identity_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||||
"cutlass_simt_u4_ifprop_relu_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||||
"cutlass_simt_u4_ifprop_hswish_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||||
"cutlass_simt_u4_ifprop_identity_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||||
"cutlass_simt_u4_ifprop_id_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||||
"cutlass_simt_u4_ifprop_relu_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | "cutlass_simt_u4_ifprop_relu_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | ||||
"cutlass_simt_u4_ifprop_hswish_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | "cutlass_simt_u4_ifprop_hswish_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | ||||
"cutlass_simt_u4_ifprop_identity_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||||
"cutlass_simt_u4_ifprop_id_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||||
"cutlass_simt_u4_ifprop_relu_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nhwc.cu", | "cutlass_simt_u4_ifprop_relu_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nhwc.cu", | ||||
"cutlass_simt_u4_ifprop_hswish_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nhwc.cu", | "cutlass_simt_u4_ifprop_hswish_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nhwc.cu", | ||||
"cutlass_simt_u4_ifprop_identity_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||||
"cutlass_simt_u4_ifprop_id_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||||
"cutlass_simt_u4_ifprop_relu_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nhwc.cu", | "cutlass_simt_u4_ifprop_relu_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nhwc.cu", | ||||
"cutlass_simt_u4_ifprop_hswish_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nhwc.cu", | "cutlass_simt_u4_ifprop_hswish_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nhwc.cu", | ||||
"cutlass_simt_u4_ifprop_identity_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||||
"cutlass_simt_u4_ifprop_id_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||||
"cutlass_simt_u4_ifprop_relu_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | "cutlass_simt_u4_ifprop_relu_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | ||||
"cutlass_simt_u4_ifprop_hswish_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | "cutlass_simt_u4_ifprop_hswish_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | ||||
"cutlass_simt_u4_ifprop_identity_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||||
"cutlass_simt_u4_ifprop_relu_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||||
"cutlass_simt_u4_ifprop_hswish_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||||
"cutlass_simt_u4_ifprop_identity_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4_nhwc.cu", | |||||
"cutlass_simt_u4_ifprop_id_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4_nhwc.cu", | |||||
"cutlass_simt_u4_ifprop_relu_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4_nhwc.cu", | "cutlass_simt_u4_ifprop_relu_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4_nhwc.cu", | ||||
"cutlass_simt_u4_ifprop_hswish_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4_nhwc.cu", | "cutlass_simt_u4_ifprop_hswish_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4_nhwc.cu", | ||||
"cutlass_simt_u4_ifprop_identity_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4_nhwc.cu", | |||||
"cutlass_simt_u4_ifprop_id_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4_nhwc.cu", | |||||
"cutlass_simt_u4_ifprop_relu_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4_nhwc.cu", | "cutlass_simt_u4_ifprop_relu_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4_nhwc.cu", | ||||
"cutlass_simt_u4_ifprop_hswish_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4_nhwc.cu", | "cutlass_simt_u4_ifprop_hswish_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4_nhwc.cu", | ||||
"cutlass_simt_s4_ifprop_identity_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||||
"cutlass_simt_s4_ifprop_id_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||||
"cutlass_simt_s4_ifprop_relu_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | "cutlass_simt_s4_ifprop_relu_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | ||||
"cutlass_simt_s4_ifprop_hswish_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | "cutlass_simt_s4_ifprop_hswish_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | ||||
"cutlass_simt_s4_ifprop_identity_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||||
"cutlass_simt_s4_ifprop_id_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||||
"cutlass_simt_s4_ifprop_relu_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | "cutlass_simt_s4_ifprop_relu_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | ||||
"cutlass_simt_s4_ifprop_hswish_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | "cutlass_simt_s4_ifprop_hswish_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | ||||
"cutlass_simt_s4_ifprop_identity_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||||
"cutlass_simt_s4_ifprop_id_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||||
"cutlass_simt_s4_ifprop_relu_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | "cutlass_simt_s4_ifprop_relu_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | ||||
"cutlass_simt_s4_ifprop_hswish_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | "cutlass_simt_s4_ifprop_hswish_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | ||||
"cutlass_simt_s4_ifprop_identity_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||||
"cutlass_simt_s4_ifprop_relu_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||||
"cutlass_simt_s4_ifprop_hswish_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||||
"cutlass_simt_s4_ifprop_identity_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||||
"cutlass_simt_s4_ifprop_id_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||||
"cutlass_simt_s4_ifprop_relu_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | "cutlass_simt_s4_ifprop_relu_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | ||||
"cutlass_simt_s4_ifprop_hswish_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | "cutlass_simt_s4_ifprop_hswish_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | ||||
"cutlass_simt_s4_ifprop_identity_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||||
"cutlass_simt_s4_ifprop_id_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||||
"cutlass_simt_s4_ifprop_relu_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nhwc.cu", | "cutlass_simt_s4_ifprop_relu_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nhwc.cu", | ||||
"cutlass_simt_s4_ifprop_hswish_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nhwc.cu", | "cutlass_simt_s4_ifprop_hswish_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nhwc.cu", | ||||
"cutlass_simt_s4_ifprop_identity_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||||
"cutlass_simt_s4_ifprop_id_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||||
"cutlass_simt_s4_ifprop_relu_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nhwc.cu", | "cutlass_simt_s4_ifprop_relu_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nhwc.cu", | ||||
"cutlass_simt_s4_ifprop_hswish_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nhwc.cu", | "cutlass_simt_s4_ifprop_hswish_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nhwc.cu", | ||||
"cutlass_simt_s4_ifprop_identity_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||||
"cutlass_simt_s4_ifprop_id_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||||
"cutlass_simt_s4_ifprop_relu_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | "cutlass_simt_s4_ifprop_relu_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | ||||
"cutlass_simt_s4_ifprop_hswish_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | "cutlass_simt_s4_ifprop_hswish_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | ||||
"cutlass_simt_s4_ifprop_identity_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||||
"cutlass_simt_s4_ifprop_relu_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||||
"cutlass_simt_s4_ifprop_hswish_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||||
"cutlass_simt_s4_ifprop_identity_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4_nhwc.cu", | |||||
"cutlass_simt_s4_ifprop_id_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4_nhwc.cu", | |||||
"cutlass_simt_s4_ifprop_relu_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4_nhwc.cu", | "cutlass_simt_s4_ifprop_relu_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4_nhwc.cu", | ||||
"cutlass_simt_s4_ifprop_hswish_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4_nhwc.cu", | "cutlass_simt_s4_ifprop_hswish_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4_nhwc.cu", | ||||
"cutlass_simt_s4_ifprop_identity_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4_nhwc.cu", | |||||
"cutlass_simt_s4_ifprop_id_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4_nhwc.cu", | |||||
"cutlass_simt_s4_ifprop_relu_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4_nhwc.cu", | "cutlass_simt_s4_ifprop_relu_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4_nhwc.cu", | ||||
"cutlass_simt_s4_ifprop_hswish_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4_nhwc.cu", | "cutlass_simt_s4_ifprop_hswish_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4_nhwc.cu", | ||||
"cutlass_simt_f32_ifprop_identity_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||||
"cutlass_simt_f32_ifprop_1x1_identity_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||||
"cutlass_simt_f32_ifprop_id_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||||
"cutlass_simt_f32_ifprop_1x1_id_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||||
"cutlass_simt_f32_ifprop_relu_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | "cutlass_simt_f32_ifprop_relu_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | ||||
"cutlass_simt_f32_ifprop_1x1_relu_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | "cutlass_simt_f32_ifprop_1x1_relu_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | ||||
"cutlass_simt_f32_ifprop_hswish_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | "cutlass_simt_f32_ifprop_hswish_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | ||||
"cutlass_simt_f32_ifprop_1x1_hswish_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | "cutlass_simt_f32_ifprop_1x1_hswish_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | ||||
"cutlass_simt_f32_ifprop_identity_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||||
"cutlass_simt_f32_ifprop_1x1_identity_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||||
"cutlass_simt_f32_ifprop_id_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||||
"cutlass_simt_f32_ifprop_1x1_id_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||||
"cutlass_simt_f32_ifprop_relu_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | "cutlass_simt_f32_ifprop_relu_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | ||||
"cutlass_simt_f32_ifprop_1x1_relu_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | "cutlass_simt_f32_ifprop_1x1_relu_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | ||||
"cutlass_simt_f32_ifprop_hswish_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | "cutlass_simt_f32_ifprop_hswish_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | ||||
"cutlass_simt_f32_ifprop_1x1_hswish_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | "cutlass_simt_f32_ifprop_1x1_hswish_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | ||||
"cutlass_simt_f32_ifprop_identity_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||||
"cutlass_simt_f32_ifprop_1x1_identity_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||||
"cutlass_simt_f32_ifprop_id_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||||
"cutlass_simt_f32_ifprop_1x1_id_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||||
"cutlass_simt_f32_ifprop_relu_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | "cutlass_simt_f32_ifprop_relu_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | ||||
"cutlass_simt_f32_ifprop_1x1_relu_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | "cutlass_simt_f32_ifprop_1x1_relu_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | ||||
"cutlass_simt_f32_ifprop_hswish_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | "cutlass_simt_f32_ifprop_hswish_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | ||||
"cutlass_simt_f32_ifprop_1x1_hswish_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | "cutlass_simt_f32_ifprop_1x1_hswish_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | ||||
"cutlass_simt_f32_ifprop_identity_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||||
"cutlass_simt_f32_ifprop_1x1_identity_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||||
"cutlass_simt_f32_ifprop_relu_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||||
"cutlass_simt_f32_ifprop_1x1_relu_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||||
"cutlass_simt_f32_ifprop_hswish_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||||
"cutlass_simt_f32_ifprop_1x1_hswish_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||||
"cutlass_simt_f32_ifprop_identity_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||||
"cutlass_simt_f32_ifprop_1x1_identity_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||||
"cutlass_simt_f32_ifprop_id_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||||
"cutlass_simt_f32_ifprop_1x1_id_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||||
"cutlass_simt_f32_ifprop_relu_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | "cutlass_simt_f32_ifprop_relu_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | ||||
"cutlass_simt_f32_ifprop_1x1_relu_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | "cutlass_simt_f32_ifprop_1x1_relu_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | ||||
"cutlass_simt_f32_ifprop_hswish_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | "cutlass_simt_f32_ifprop_hswish_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | ||||
"cutlass_simt_f32_ifprop_1x1_hswish_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | "cutlass_simt_f32_ifprop_1x1_hswish_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | ||||
"cutlass_simt_f32_ifprop_identity_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nchw.cu", | |||||
"cutlass_simt_f32_ifprop_1x1_identity_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nchw.cu", | |||||
"cutlass_simt_f32_ifprop_id_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nchw.cu", | |||||
"cutlass_simt_f32_ifprop_1x1_id_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nchw.cu", | |||||
"cutlass_simt_f32_ifprop_relu_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nchw.cu", | "cutlass_simt_f32_ifprop_relu_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nchw.cu", | ||||
"cutlass_simt_f32_ifprop_1x1_relu_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nchw.cu", | "cutlass_simt_f32_ifprop_1x1_relu_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nchw.cu", | ||||
"cutlass_simt_f32_ifprop_hswish_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nchw.cu", | "cutlass_simt_f32_ifprop_hswish_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nchw.cu", | ||||
"cutlass_simt_f32_ifprop_1x1_hswish_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nchw.cu", | "cutlass_simt_f32_ifprop_1x1_hswish_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nchw.cu", | ||||
"cutlass_simt_f32_ifprop_identity_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nchw.cu", | |||||
"cutlass_simt_f32_ifprop_1x1_identity_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nchw.cu", | |||||
"cutlass_simt_f32_ifprop_id_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nchw.cu", | |||||
"cutlass_simt_f32_ifprop_1x1_id_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nchw.cu", | |||||
"cutlass_simt_f32_ifprop_relu_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nchw.cu", | "cutlass_simt_f32_ifprop_relu_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nchw.cu", | ||||
"cutlass_simt_f32_ifprop_1x1_relu_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nchw.cu", | "cutlass_simt_f32_ifprop_1x1_relu_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nchw.cu", | ||||
"cutlass_simt_f32_ifprop_hswish_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nchw.cu", | "cutlass_simt_f32_ifprop_hswish_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nchw.cu", | ||||
"cutlass_simt_f32_ifprop_1x1_hswish_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nchw.cu", | "cutlass_simt_f32_ifprop_1x1_hswish_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nchw.cu", | ||||
"cutlass_simt_f32_ifprop_identity_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||||
"cutlass_simt_f32_ifprop_1x1_identity_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||||
"cutlass_simt_f32_ifprop_id_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||||
"cutlass_simt_f32_ifprop_1x1_id_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||||
"cutlass_simt_f32_ifprop_relu_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | "cutlass_simt_f32_ifprop_relu_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | ||||
"cutlass_simt_f32_ifprop_1x1_relu_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | "cutlass_simt_f32_ifprop_1x1_relu_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | ||||
"cutlass_simt_f32_ifprop_hswish_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | "cutlass_simt_f32_ifprop_hswish_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | ||||
"cutlass_simt_f32_ifprop_1x1_hswish_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | "cutlass_simt_f32_ifprop_1x1_hswish_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | ||||
"cutlass_simt_f32_ifprop_identity_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||||
"cutlass_simt_f32_ifprop_1x1_identity_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||||
"cutlass_simt_f32_ifprop_relu_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||||
"cutlass_simt_f32_ifprop_1x1_relu_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||||
"cutlass_simt_f32_ifprop_hswish_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||||
"cutlass_simt_f32_ifprop_1x1_hswish_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||||
"cutlass_simt_f32_ifprop_identity_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4_nchw.cu", | |||||
"cutlass_simt_f32_ifprop_1x1_identity_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4_nchw.cu", | |||||
"cutlass_simt_f32_ifprop_id_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4_nchw.cu", | |||||
"cutlass_simt_f32_ifprop_1x1_id_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4_nchw.cu", | |||||
"cutlass_simt_f32_ifprop_relu_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4_nchw.cu", | "cutlass_simt_f32_ifprop_relu_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4_nchw.cu", | ||||
"cutlass_simt_f32_ifprop_1x1_relu_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4_nchw.cu", | "cutlass_simt_f32_ifprop_1x1_relu_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4_nchw.cu", | ||||
"cutlass_simt_f32_ifprop_hswish_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4_nchw.cu", | "cutlass_simt_f32_ifprop_hswish_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4_nchw.cu", | ||||
"cutlass_simt_f32_ifprop_1x1_hswish_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4_nchw.cu", | "cutlass_simt_f32_ifprop_1x1_hswish_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4_nchw.cu", | ||||
"cutlass_simt_f32_ifprop_identity_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4_nchw.cu", | |||||
"cutlass_simt_f32_ifprop_1x1_identity_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4_nchw.cu", | |||||
"cutlass_simt_f32_ifprop_id_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4_nchw.cu", | |||||
"cutlass_simt_f32_ifprop_1x1_id_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4_nchw.cu", | |||||
"cutlass_simt_f32_ifprop_relu_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4_nchw.cu", | "cutlass_simt_f32_ifprop_relu_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4_nchw.cu", | ||||
"cutlass_simt_f32_ifprop_1x1_relu_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4_nchw.cu", | "cutlass_simt_f32_ifprop_1x1_relu_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4_nchw.cu", | ||||
"cutlass_simt_f32_ifprop_hswish_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4_nchw.cu", | "cutlass_simt_f32_ifprop_hswish_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4_nchw.cu", | ||||
"cutlass_simt_f32_ifprop_1x1_hswish_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4_nchw.cu", | "cutlass_simt_f32_ifprop_1x1_hswish_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4_nchw.cu", | ||||
"cutlass_tensorop_s8_i8816fprop_identity_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_1x1_identity_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_relu_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_hswish_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_identity_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_1x1_identity_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_relu_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_hswish_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_identity_s8_128x128x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_1x1_identity_s8_128x128x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_relu_s8_128x128x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_128x128x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_hswish_s8_128x128x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_128x128x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_identity_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_1x1_identity_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_relu_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_hswish_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_identity_s8_128x64x64_64x32x64_2_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_1x1_identity_s8_128x64x64_64x32x64_2_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_relu_s8_128x64x64_64x32x64_2_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_128x64x64_64x32x64_2_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_hswish_s8_128x64x64_64x32x64_2_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_128x64x64_64x32x64_2_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_identity_s8_64x64x64_32x32x64_2_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_1x1_identity_s8_64x64x64_32x32x64_2_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_relu_s8_64x64x64_32x32x64_2_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_64x64x64_32x32x64_2_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_hswish_s8_64x64x64_32x32x64_2_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_64x64x64_32x32x64_2_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_identity_s8_32x64x64_32x16x64_2_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_1x1_identity_s8_32x64x64_32x16x64_2_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_relu_s8_32x64x64_32x16x64_2_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_32x64x64_32x16x64_2_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_hswish_s8_32x64x64_32x16x64_2_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_32x64x64_32x16x64_2_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_identity_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_1x1_identity_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_relu_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_hswish_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_identity_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_1x1_identity_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_roc_id_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_1x1_roc_id_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_roc_relu_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_1x1_roc_relu_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_roc_hswish_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_1x1_roc_hswish_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_roc_id_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_1x1_roc_id_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_roc_relu_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_1x1_roc_relu_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_roc_hswish_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_1x1_roc_hswish_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_roc_id_s8_128x128x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_1x1_roc_id_s8_128x128x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_roc_relu_s8_128x128x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_1x1_roc_relu_s8_128x128x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_roc_hswish_s8_128x128x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_1x1_roc_hswish_s8_128x128x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_roc_id_s8_128x64x64_64x32x64_2_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_1x1_roc_id_s8_128x64x64_64x32x64_2_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_roc_relu_s8_128x64x64_64x32x64_2_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_1x1_roc_relu_s8_128x64x64_64x32x64_2_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_roc_hswish_s8_128x64x64_64x32x64_2_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_1x1_roc_hswish_s8_128x64x64_64x32x64_2_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_roc_id_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_1x1_roc_id_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_roc_relu_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_1x1_roc_relu_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_roc_hswish_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_1x1_roc_hswish_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_roc_id_s8_128x64x32_64x32x32_1_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_1x1_roc_id_s8_128x64x32_64x32x32_1_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_roc_relu_s8_128x64x32_64x32x32_1_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_1x1_roc_relu_s8_128x64x32_64x32x32_1_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_roc_hswish_s8_128x64x32_64x32x32_1_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_1x1_roc_hswish_s8_128x64x32_64x32x32_1_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_roc_id_s8_128x32x32_64x32x32_1_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_1x1_roc_id_s8_128x32x32_64x32x32_1_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_roc_relu_s8_128x32x32_64x32x32_1_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_1x1_roc_relu_s8_128x32x32_64x32x32_1_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_roc_hswish_s8_128x32x32_64x32x32_1_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_1x1_roc_hswish_s8_128x32x32_64x32x32_1_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_roc_id_s8_64x128x32_32x64x32_1_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_1x1_roc_id_s8_64x128x32_32x64x32_1_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_roc_relu_s8_64x128x32_32x64x32_1_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_1x1_roc_relu_s8_64x128x32_32x64x32_1_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_roc_hswish_s8_64x128x32_32x64x32_1_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_1x1_roc_hswish_s8_64x128x32_32x64x32_1_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_roc_id_s8_32x128x32_32x64x32_1_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_1x1_roc_id_s8_32x128x32_32x64x32_1_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_roc_relu_s8_32x128x32_32x64x32_1_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_1x1_roc_relu_s8_32x128x32_32x64x32_1_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_roc_hswish_s8_32x128x32_32x64x32_1_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_1x1_roc_hswish_s8_32x128x32_32x64x32_1_nc32hw32_c32rsk32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_id_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_1x1_id_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_relu_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | "cutlass_tensorop_s8_i8816fprop_relu_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | ||||
"cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | "cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | ||||
"cutlass_tensorop_s8_i8816fprop_hswish_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | "cutlass_tensorop_s8_i8816fprop_hswish_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | ||||
"cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | "cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | ||||
"cutlass_tensorop_s8_i8816fprop_identity_s8_128x128x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_1x1_identity_s8_128x128x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_id_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_1x1_id_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_relu_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_hswish_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_id_s8_128x128x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_1x1_id_s8_128x128x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_relu_s8_128x128x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | "cutlass_tensorop_s8_i8816fprop_relu_s8_128x128x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | ||||
"cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_128x128x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | "cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_128x128x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | ||||
"cutlass_tensorop_s8_i8816fprop_hswish_s8_128x128x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | "cutlass_tensorop_s8_i8816fprop_hswish_s8_128x128x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | ||||
"cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_128x128x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | "cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_128x128x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | ||||
"cutlass_tensorop_s8_i8816fprop_identity_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_1x1_identity_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_relu_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_hswish_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_identity_s8_128x64x64_64x32x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_1x1_identity_s8_128x64x64_64x32x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_id_s8_128x64x64_64x32x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_1x1_id_s8_128x64x64_64x32x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_relu_s8_128x64x64_64x32x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | "cutlass_tensorop_s8_i8816fprop_relu_s8_128x64x64_64x32x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | ||||
"cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_128x64x64_64x32x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | "cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_128x64x64_64x32x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | ||||
"cutlass_tensorop_s8_i8816fprop_hswish_s8_128x64x64_64x32x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | "cutlass_tensorop_s8_i8816fprop_hswish_s8_128x64x64_64x32x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | ||||
"cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_128x64x64_64x32x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | "cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_128x64x64_64x32x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | ||||
"cutlass_tensorop_s8_i8816fprop_identity_s8_64x64x64_32x32x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_1x1_identity_s8_64x64x64_32x32x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_relu_s8_64x64x64_32x32x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_64x64x64_32x32x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_hswish_s8_64x64x64_32x32x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_64x64x64_32x32x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_identity_s8_32x64x64_16x32x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_1x1_identity_s8_32x64x64_16x32x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_relu_s8_32x64x64_16x32x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_32x64x64_16x32x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_hswish_s8_32x64x64_16x32x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_32x64x64_16x32x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_identity_s4_256x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_relu_s4_256x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_hswish_s4_256x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_identity_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_relu_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_hswish_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_identity_u4_s4_256x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_relu_u4_s4_256x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_identity_u4_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_relu_u4_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_identity_s4_128x32x64_64x32x64_2_nhwc_nc8hw8.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_1x1_identity_s4_128x32x64_64x32x64_2_nhwc_nc8hw8.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_relu_s4_128x32x64_64x32x64_2_nhwc_nc8hw8.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_1x1_relu_s4_128x32x64_64x32x64_2_nhwc_nc8hw8.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_hswish_s4_128x32x64_64x32x64_2_nhwc_nc8hw8.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_1x1_hswish_s4_128x32x64_64x32x64_2_nhwc_nc8hw8.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_identity_s4_128x64x64_64x64x64_2_nhwc_nc8hw8.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_1x1_identity_s4_128x64x64_64x64x64_2_nhwc_nc8hw8.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_relu_s4_128x64x64_64x64x64_2_nhwc_nc8hw8.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_1x1_relu_s4_128x64x64_64x64x64_2_nhwc_nc8hw8.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_hswish_s4_128x64x64_64x64x64_2_nhwc_nc8hw8.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_1x1_hswish_s4_128x64x64_64x64x64_2_nhwc_nc8hw8.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_identity_s4_128x32x64_64x32x64_2_nhwc_nc16hw16.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_1x1_identity_s4_128x32x64_64x32x64_2_nhwc_nc16hw16.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_relu_s4_128x32x64_64x32x64_2_nhwc_nc16hw16.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_1x1_relu_s4_128x32x64_64x32x64_2_nhwc_nc16hw16.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_hswish_s4_128x32x64_64x32x64_2_nhwc_nc16hw16.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_1x1_hswish_s4_128x32x64_64x32x64_2_nhwc_nc16hw16.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_identity_s4_128x64x64_64x64x64_2_nhwc_nc16hw16.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_1x1_identity_s4_128x64x64_64x64x64_2_nhwc_nc16hw16.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_relu_s4_128x64x64_64x64x64_2_nhwc_nc16hw16.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_1x1_relu_s4_128x64x64_64x64x64_2_nhwc_nc16hw16.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_hswish_s4_128x64x64_64x64x64_2_nhwc_nc16hw16.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_1x1_hswish_s4_128x64x64_64x64x64_2_nhwc_nc16hw16.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_identity_s4_128x32x64_64x32x64_2_nhwc_nc32hw32.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_1x1_identity_s4_128x32x64_64x32x64_2_nhwc_nc32hw32.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_relu_s4_128x32x64_64x32x64_2_nhwc_nc32hw32.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_1x1_relu_s4_128x32x64_64x32x64_2_nhwc_nc32hw32.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_hswish_s4_128x32x64_64x32x64_2_nhwc_nc32hw32.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_1x1_hswish_s4_128x32x64_64x32x64_2_nhwc_nc32hw32.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_identity_s4_128x64x64_64x64x64_2_nhwc_nc32hw32.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_1x1_identity_s4_128x64x64_64x64x64_2_nhwc_nc32hw32.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_relu_s4_128x64x64_64x64x64_2_nhwc_nc32hw32.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_1x1_relu_s4_128x64x64_64x64x64_2_nhwc_nc32hw32.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_hswish_s4_128x64x64_64x64x64_2_nhwc_nc32hw32.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_1x1_hswish_s4_128x64x64_64x64x64_2_nhwc_nc32hw32.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_identity_u4_s4_128x32x64_64x32x64_2_nhwc_nc8hw8.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_1x1_identity_u4_s4_128x32x64_64x32x64_2_nhwc_nc8hw8.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_relu_u4_s4_128x32x64_64x32x64_2_nhwc_nc8hw8.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_1x1_relu_u4_s4_128x32x64_64x32x64_2_nhwc_nc8hw8.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_identity_u4_s4_128x64x64_64x64x64_2_nhwc_nc8hw8.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_1x1_identity_u4_s4_128x64x64_64x64x64_2_nhwc_nc8hw8.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_relu_u4_s4_128x64x64_64x64x64_2_nhwc_nc8hw8.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_1x1_relu_u4_s4_128x64x64_64x64x64_2_nhwc_nc8hw8.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_identity_u4_s4_128x32x64_64x32x64_2_nhwc_nc16hw16.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_1x1_identity_u4_s4_128x32x64_64x32x64_2_nhwc_nc16hw16.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_relu_u4_s4_128x32x64_64x32x64_2_nhwc_nc16hw16.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_1x1_relu_u4_s4_128x32x64_64x32x64_2_nhwc_nc16hw16.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_identity_u4_s4_128x64x64_64x64x64_2_nhwc_nc16hw16.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_1x1_identity_u4_s4_128x64x64_64x64x64_2_nhwc_nc16hw16.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_relu_u4_s4_128x64x64_64x64x64_2_nhwc_nc16hw16.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_1x1_relu_u4_s4_128x64x64_64x64x64_2_nhwc_nc16hw16.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_identity_u4_s4_128x32x64_64x32x64_2_nhwc_nc32hw32.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_1x1_identity_u4_s4_128x32x64_64x32x64_2_nhwc_nc32hw32.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_relu_u4_s4_128x32x64_64x32x64_2_nhwc_nc32hw32.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_1x1_relu_u4_s4_128x32x64_64x32x64_2_nhwc_nc32hw32.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_identity_u4_s4_128x64x64_64x64x64_2_nhwc_nc32hw32.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_1x1_identity_u4_s4_128x64x64_64x64x64_2_nhwc_nc32hw32.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_relu_u4_s4_128x64x64_64x64x64_2_nhwc_nc32hw32.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_1x1_relu_u4_s4_128x64x64_64x64x64_2_nhwc_nc32hw32.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_id_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_1x1_id_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_relu_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_hswish_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_id_s8_128x64x32_64x32x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_1x1_id_s8_128x64x32_64x32x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_relu_s8_128x64x32_64x32x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_128x64x32_64x32x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_hswish_s8_128x64x32_64x32x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_128x64x32_64x32x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_id_s8_128x32x32_64x32x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_1x1_id_s8_128x32x32_64x32x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_relu_s8_128x32x32_64x32x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_128x32x32_64x32x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_hswish_s8_128x32x32_64x32x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_128x32x32_64x32x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_id_s8_64x128x32_32x64x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_1x1_id_s8_64x128x32_32x64x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_relu_s8_64x128x32_32x64x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_64x128x32_32x64x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_hswish_s8_64x128x32_32x64x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_64x128x32_32x64x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_id_s8_32x128x32_32x64x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_1x1_id_s8_32x128x32_32x64x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_relu_s8_32x128x32_32x64x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_32x128x32_32x64x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_hswish_s8_32x128x32_32x64x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_32x128x32_32x64x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_roc_id_s4_128x256x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_roc_relu_s4_128x256x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_roc_hswish_s4_128x256x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_roc_id_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_roc_relu_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_roc_hswish_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_roc_id_s4_128x64x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_roc_relu_s4_128x64x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_roc_hswish_s4_128x64x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_roc_id_s4_128x64x64_64x64x64_1_nc64hw64_c64rsk64.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_roc_relu_s4_128x64x64_64x64x64_1_nc64hw64_c64rsk64.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_roc_hswish_s4_128x64x64_64x64x64_1_nc64hw64_c64rsk64.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_roc_id_u4_s4_128x256x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_roc_relu_u4_s4_128x256x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_roc_id_u4_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_roc_relu_u4_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_roc_id_u4_s4_128x64x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_roc_relu_u4_s4_128x64x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_roc_id_u4_s4_128x64x64_64x64x64_1_nc64hw64_c64rsk64.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_roc_relu_u4_s4_128x64x64_64x64x64_1_nc64hw64_c64rsk64.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_id_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_1x1_id_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_relu_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_1x1_relu_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_hswish_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_1x1_hswish_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_roc_id_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_1x1_roc_id_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_roc_relu_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_1x1_roc_relu_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_roc_hswish_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_1x1_roc_hswish_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_id_s4_128x64x64_64x64x64_1_nhwc_nc8hw8.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_1x1_id_s4_128x64x64_64x64x64_1_nhwc_nc8hw8.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_relu_s4_128x64x64_64x64x64_1_nhwc_nc8hw8.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_1x1_relu_s4_128x64x64_64x64x64_1_nhwc_nc8hw8.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_hswish_s4_128x64x64_64x64x64_1_nhwc_nc8hw8.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_1x1_hswish_s4_128x64x64_64x64x64_1_nhwc_nc8hw8.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_roc_id_s4_128x64x64_64x64x64_1_nhwc_nc8hw8.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_1x1_roc_id_s4_128x64x64_64x64x64_1_nhwc_nc8hw8.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_roc_relu_s4_128x64x64_64x64x64_1_nhwc_nc8hw8.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_1x1_roc_relu_s4_128x64x64_64x64x64_1_nhwc_nc8hw8.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_roc_hswish_s4_128x64x64_64x64x64_1_nhwc_nc8hw8.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_1x1_roc_hswish_s4_128x64x64_64x64x64_1_nhwc_nc8hw8.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_id_s4_128x32x64_64x32x64_1_nhwc_nc16hw16.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_1x1_id_s4_128x32x64_64x32x64_1_nhwc_nc16hw16.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_relu_s4_128x32x64_64x32x64_1_nhwc_nc16hw16.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_1x1_relu_s4_128x32x64_64x32x64_1_nhwc_nc16hw16.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_hswish_s4_128x32x64_64x32x64_1_nhwc_nc16hw16.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_1x1_hswish_s4_128x32x64_64x32x64_1_nhwc_nc16hw16.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_roc_id_s4_128x32x64_64x32x64_1_nhwc_nc16hw16.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_1x1_roc_id_s4_128x32x64_64x32x64_1_nhwc_nc16hw16.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_roc_relu_s4_128x32x64_64x32x64_1_nhwc_nc16hw16.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_1x1_roc_relu_s4_128x32x64_64x32x64_1_nhwc_nc16hw16.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_roc_hswish_s4_128x32x64_64x32x64_1_nhwc_nc16hw16.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_1x1_roc_hswish_s4_128x32x64_64x32x64_1_nhwc_nc16hw16.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_id_s4_128x64x64_64x64x64_1_nhwc_nc16hw16.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_1x1_id_s4_128x64x64_64x64x64_1_nhwc_nc16hw16.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_relu_s4_128x64x64_64x64x64_1_nhwc_nc16hw16.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_1x1_relu_s4_128x64x64_64x64x64_1_nhwc_nc16hw16.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_hswish_s4_128x64x64_64x64x64_1_nhwc_nc16hw16.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_1x1_hswish_s4_128x64x64_64x64x64_1_nhwc_nc16hw16.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_roc_id_s4_128x64x64_64x64x64_1_nhwc_nc16hw16.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_1x1_roc_id_s4_128x64x64_64x64x64_1_nhwc_nc16hw16.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_roc_relu_s4_128x64x64_64x64x64_1_nhwc_nc16hw16.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_1x1_roc_relu_s4_128x64x64_64x64x64_1_nhwc_nc16hw16.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_roc_hswish_s4_128x64x64_64x64x64_1_nhwc_nc16hw16.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_1x1_roc_hswish_s4_128x64x64_64x64x64_1_nhwc_nc16hw16.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_id_s4_128x32x64_64x32x64_1_nhwc_nc32hw32.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_1x1_id_s4_128x32x64_64x32x64_1_nhwc_nc32hw32.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_relu_s4_128x32x64_64x32x64_1_nhwc_nc32hw32.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_1x1_relu_s4_128x32x64_64x32x64_1_nhwc_nc32hw32.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_hswish_s4_128x32x64_64x32x64_1_nhwc_nc32hw32.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_1x1_hswish_s4_128x32x64_64x32x64_1_nhwc_nc32hw32.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_roc_id_s4_128x32x64_64x32x64_1_nhwc_nc32hw32.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_1x1_roc_id_s4_128x32x64_64x32x64_1_nhwc_nc32hw32.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_roc_relu_s4_128x32x64_64x32x64_1_nhwc_nc32hw32.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_1x1_roc_relu_s4_128x32x64_64x32x64_1_nhwc_nc32hw32.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_roc_hswish_s4_128x32x64_64x32x64_1_nhwc_nc32hw32.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_1x1_roc_hswish_s4_128x32x64_64x32x64_1_nhwc_nc32hw32.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_id_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_1x1_id_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_relu_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_1x1_relu_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_hswish_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_1x1_hswish_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_roc_id_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_1x1_roc_id_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_roc_relu_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_1x1_roc_relu_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_roc_hswish_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_1x1_roc_hswish_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_id_u4_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_1x1_id_u4_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_relu_u4_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_1x1_relu_u4_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_roc_id_u4_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_1x1_roc_id_u4_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_roc_relu_u4_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_1x1_roc_relu_u4_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_id_u4_s4_128x64x64_64x64x64_1_nhwc_nc8hw8.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_1x1_id_u4_s4_128x64x64_64x64x64_1_nhwc_nc8hw8.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_relu_u4_s4_128x64x64_64x64x64_1_nhwc_nc8hw8.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_1x1_relu_u4_s4_128x64x64_64x64x64_1_nhwc_nc8hw8.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_roc_id_u4_s4_128x64x64_64x64x64_1_nhwc_nc8hw8.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_1x1_roc_id_u4_s4_128x64x64_64x64x64_1_nhwc_nc8hw8.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_roc_relu_u4_s4_128x64x64_64x64x64_1_nhwc_nc8hw8.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_1x1_roc_relu_u4_s4_128x64x64_64x64x64_1_nhwc_nc8hw8.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_id_u4_s4_128x32x64_64x32x64_1_nhwc_nc16hw16.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_1x1_id_u4_s4_128x32x64_64x32x64_1_nhwc_nc16hw16.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_relu_u4_s4_128x32x64_64x32x64_1_nhwc_nc16hw16.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_1x1_relu_u4_s4_128x32x64_64x32x64_1_nhwc_nc16hw16.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_roc_id_u4_s4_128x32x64_64x32x64_1_nhwc_nc16hw16.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_1x1_roc_id_u4_s4_128x32x64_64x32x64_1_nhwc_nc16hw16.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_roc_relu_u4_s4_128x32x64_64x32x64_1_nhwc_nc16hw16.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_1x1_roc_relu_u4_s4_128x32x64_64x32x64_1_nhwc_nc16hw16.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_id_u4_s4_128x64x64_64x64x64_1_nhwc_nc16hw16.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_1x1_id_u4_s4_128x64x64_64x64x64_1_nhwc_nc16hw16.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_relu_u4_s4_128x64x64_64x64x64_1_nhwc_nc16hw16.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_1x1_relu_u4_s4_128x64x64_64x64x64_1_nhwc_nc16hw16.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_roc_id_u4_s4_128x64x64_64x64x64_1_nhwc_nc16hw16.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_1x1_roc_id_u4_s4_128x64x64_64x64x64_1_nhwc_nc16hw16.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_roc_relu_u4_s4_128x64x64_64x64x64_1_nhwc_nc16hw16.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_1x1_roc_relu_u4_s4_128x64x64_64x64x64_1_nhwc_nc16hw16.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_id_u4_s4_128x32x64_64x32x64_1_nhwc_nc32hw32.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_1x1_id_u4_s4_128x32x64_64x32x64_1_nhwc_nc32hw32.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_relu_u4_s4_128x32x64_64x32x64_1_nhwc_nc32hw32.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_1x1_relu_u4_s4_128x32x64_64x32x64_1_nhwc_nc32hw32.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_roc_id_u4_s4_128x32x64_64x32x64_1_nhwc_nc32hw32.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_1x1_roc_id_u4_s4_128x32x64_64x32x64_1_nhwc_nc32hw32.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_roc_relu_u4_s4_128x32x64_64x32x64_1_nhwc_nc32hw32.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_1x1_roc_relu_u4_s4_128x32x64_64x32x64_1_nhwc_nc32hw32.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_id_u4_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_1x1_id_u4_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_relu_u4_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_1x1_relu_u4_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_roc_id_u4_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_1x1_roc_id_u4_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_roc_relu_u4_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_1x1_roc_relu_u4_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", | |||||
] | ] |
@@ -217,56 +217,68 @@ void ConvBiasForwardImpl::AlgoPack::fill_imma_algos() { | |||||
#if CUDA_VERSION >= 10020 | #if CUDA_VERSION >= 10020 | ||||
{ | { | ||||
using AlgoParam = AlgoInt8NCHW32IMMAImplicitGemm::AlgoParam; | using AlgoParam = AlgoInt8NCHW32IMMAImplicitGemm::AlgoParam; | ||||
int8_nchw32_imma.emplace_back(AlgoParam{128, 256, 64, 64, 64, 64}); | |||||
int8_nchw32_imma.emplace_back(AlgoParam{256, 128, 64, 64, 64, 64}); | |||||
int8_nchw32_imma.emplace_back(AlgoParam{128, 128, 64, 64, 64, 64}); | |||||
int8_nchw32_imma.emplace_back(AlgoParam{64, 128, 64, 32, 64, 64}); | |||||
int8_nchw32_imma.emplace_back(AlgoParam{128, 64, 64, 64, 32, 64}); | |||||
int8_nchw32_imma.emplace_back(AlgoParam{64, 64, 64, 32, 32, 64}); | |||||
int8_nchw32_imma.emplace_back(AlgoParam{32, 64, 64, 32, 16, 64}); | |||||
int8_nchw32_imma.emplace_back(AlgoParam{128, 256, 64, 64, 64, 64, 2}); | |||||
int8_nchw32_imma.emplace_back(AlgoParam{256, 128, 64, 64, 64, 64, 2}); | |||||
int8_nchw32_imma.emplace_back(AlgoParam{128, 128, 64, 64, 64, 64, 2}); | |||||
int8_nchw32_imma.emplace_back(AlgoParam{128, 64, 64, 64, 32, 64, 2}); | |||||
int8_nchw32_imma.emplace_back(AlgoParam{64, 128, 64, 32, 64, 64, 2}); | |||||
int8_nchw32_imma.emplace_back(AlgoParam{128, 64, 32, 64, 32, 32, 1}); | |||||
int8_nchw32_imma.emplace_back(AlgoParam{128, 32, 32, 64, 32, 32, 1}); | |||||
int8_nchw32_imma.emplace_back(AlgoParam{64, 128, 32, 32, 64, 32, 1}); | |||||
int8_nchw32_imma.emplace_back(AlgoParam{32, 128, 32, 32, 64, 32, 1}); | |||||
} | } | ||||
{ | { | ||||
using AlgoParam = AlgoInt4Int4NCHW64IMMAImplicitGemm::AlgoParam; | using AlgoParam = AlgoInt4Int4NCHW64IMMAImplicitGemm::AlgoParam; | ||||
int4_int4_nchw64_imma.emplace_back( | int4_int4_nchw64_imma.emplace_back( | ||||
AlgoParam{128, 128, 128, 64, 64, 128}); | |||||
AlgoParam{128, 128, 128, 64, 64, 128, 2}); | |||||
int4_int4_nchw64_imma.emplace_back( | int4_int4_nchw64_imma.emplace_back( | ||||
AlgoParam{256, 128, 128, 64, 64, 128}); | |||||
AlgoParam{128, 256, 128, 64, 64, 128, 2}); | |||||
int4_int4_nchw64_imma.emplace_back( | |||||
AlgoParam{128, 64, 128, 64, 64, 128, 2}); | |||||
int4_int4_nchw64_imma.emplace_back( | |||||
AlgoParam{128, 64, 64, 64, 64, 64, 1}); | |||||
} | } | ||||
{ | { | ||||
using AlgoParam = AlgoUInt4Int4NCHW64IMMAImplicitGemm::AlgoParam; | using AlgoParam = AlgoUInt4Int4NCHW64IMMAImplicitGemm::AlgoParam; | ||||
uint4_int4_nchw64_imma.emplace_back( | uint4_int4_nchw64_imma.emplace_back( | ||||
AlgoParam{128, 128, 128, 64, 64, 128}); | |||||
AlgoParam{128, 128, 128, 64, 64, 128, 2}); | |||||
uint4_int4_nchw64_imma.emplace_back( | |||||
AlgoParam{128, 256, 128, 64, 64, 128, 2}); | |||||
uint4_int4_nchw64_imma.emplace_back( | uint4_int4_nchw64_imma.emplace_back( | ||||
AlgoParam{256, 128, 128, 64, 64, 128}); | |||||
AlgoParam{128, 64, 128, 64, 64, 128, 2}); | |||||
uint4_int4_nchw64_imma.emplace_back( | |||||
AlgoParam{128, 64, 64, 64, 64, 64, 1}); | |||||
} | } | ||||
{ | { | ||||
using AlgoParam = AlgoInt4Int4NHWCIMMAImplicitGemm::AlgoParam; | using AlgoParam = AlgoInt4Int4NHWCIMMAImplicitGemm::AlgoParam; | ||||
int4_int4_nhwc_imma.emplace_back( | int4_int4_nhwc_imma.emplace_back( | ||||
AlgoParam{128, 32, 64, 64, 32, 64, 32}); | |||||
AlgoParam{128, 32, 64, 64, 32, 64, 1, 32}); | |||||
int4_int4_nhwc_imma.emplace_back( | |||||
AlgoParam{128, 32, 64, 64, 32, 64, 1, 16}); | |||||
int4_int4_nhwc_imma.emplace_back( | |||||
AlgoParam{128, 32, 64, 64, 32, 64, 1, 8}); | |||||
int4_int4_nhwc_imma.emplace_back( | int4_int4_nhwc_imma.emplace_back( | ||||
AlgoParam{128, 32, 64, 64, 32, 64, 16}); | |||||
int4_int4_nhwc_imma.emplace_back(AlgoParam{128, 32, 64, 64, 32, 64, 8}); | |||||
AlgoParam{128, 64, 64, 64, 64, 64, 1, 32}); | |||||
int4_int4_nhwc_imma.emplace_back( | int4_int4_nhwc_imma.emplace_back( | ||||
AlgoParam{128, 64, 64, 64, 64, 64, 32}); | |||||
AlgoParam{128, 64, 64, 64, 64, 64, 1, 16}); | |||||
int4_int4_nhwc_imma.emplace_back( | int4_int4_nhwc_imma.emplace_back( | ||||
AlgoParam{128, 64, 64, 64, 64, 64, 16}); | |||||
int4_int4_nhwc_imma.emplace_back(AlgoParam{128, 64, 64, 64, 64, 64, 8}); | |||||
AlgoParam{128, 64, 64, 64, 64, 64, 1, 8}); | |||||
} | } | ||||
{ | { | ||||
using AlgoParam = AlgoUInt4Int4NHWCIMMAImplicitGemm::AlgoParam; | using AlgoParam = AlgoUInt4Int4NHWCIMMAImplicitGemm::AlgoParam; | ||||
uint4_int4_nhwc_imma.emplace_back( | uint4_int4_nhwc_imma.emplace_back( | ||||
AlgoParam{128, 32, 64, 64, 32, 64, 32}); | |||||
AlgoParam{128, 32, 64, 64, 32, 64, 1, 32}); | |||||
uint4_int4_nhwc_imma.emplace_back( | uint4_int4_nhwc_imma.emplace_back( | ||||
AlgoParam{128, 32, 64, 64, 32, 64, 16}); | |||||
AlgoParam{128, 32, 64, 64, 32, 64, 1, 16}); | |||||
uint4_int4_nhwc_imma.emplace_back( | uint4_int4_nhwc_imma.emplace_back( | ||||
AlgoParam{128, 32, 64, 64, 32, 64, 8}); | |||||
AlgoParam{128, 32, 64, 64, 32, 64, 1, 8}); | |||||
uint4_int4_nhwc_imma.emplace_back( | uint4_int4_nhwc_imma.emplace_back( | ||||
AlgoParam{128, 64, 64, 64, 64, 64, 32}); | |||||
AlgoParam{128, 64, 64, 64, 64, 64, 1, 32}); | |||||
uint4_int4_nhwc_imma.emplace_back( | uint4_int4_nhwc_imma.emplace_back( | ||||
AlgoParam{128, 64, 64, 64, 64, 64, 16}); | |||||
AlgoParam{128, 64, 64, 64, 64, 64, 1, 16}); | |||||
uint4_int4_nhwc_imma.emplace_back( | uint4_int4_nhwc_imma.emplace_back( | ||||
AlgoParam{128, 64, 64, 64, 64, 64, 8}); | |||||
AlgoParam{128, 64, 64, 64, 64, 64, 1, 8}); | |||||
} | } | ||||
#endif | #endif | ||||
} | } | ||||
@@ -279,10 +291,8 @@ void ConvBiasForwardImpl::AlgoPack::fill_dp4a_algos() { | |||||
int8_nchw4_dotprod.emplace_back(AlgoParam{64, 128, 32, 64, 32, 32, 2}); | int8_nchw4_dotprod.emplace_back(AlgoParam{64, 128, 32, 64, 32, 32, 2}); | ||||
int8_nchw4_dotprod.emplace_back(AlgoParam{32, 128, 32, 32, 64, 32, 2}); | int8_nchw4_dotprod.emplace_back(AlgoParam{32, 128, 32, 32, 64, 32, 2}); | ||||
int8_nchw4_dotprod.emplace_back(AlgoParam{128, 32, 32, 64, 32, 32, 2}); | int8_nchw4_dotprod.emplace_back(AlgoParam{128, 32, 32, 64, 32, 32, 2}); | ||||
int8_nchw4_dotprod.emplace_back(AlgoParam{64, 64, 32, 64, 32, 32, 2}); | |||||
int8_nchw4_dotprod.emplace_back(AlgoParam{32, 64, 32, 32, 64, 32, 2}); | int8_nchw4_dotprod.emplace_back(AlgoParam{32, 64, 32, 32, 64, 32, 2}); | ||||
int8_nchw4_dotprod.emplace_back(AlgoParam{64, 32, 32, 64, 32, 32, 2}); | int8_nchw4_dotprod.emplace_back(AlgoParam{64, 32, 32, 64, 32, 32, 2}); | ||||
int8_nchw4_dotprod.emplace_back(AlgoParam{32, 32, 32, 32, 32, 32, 2}); | |||||
int8_nchw4_dotprod.emplace_back(AlgoParam{16, 128, 16, 16, 128, 16, 1}); | int8_nchw4_dotprod.emplace_back(AlgoParam{16, 128, 16, 16, 128, 16, 1}); | ||||
int8_nchw4_dotprod.emplace_back(AlgoParam{16, 64, 8, 16, 64, 8, 2}); | int8_nchw4_dotprod.emplace_back(AlgoParam{16, 64, 8, 16, 64, 8, 2}); | ||||
} | } | ||||
@@ -723,6 +723,7 @@ public: | |||||
int warp_m; | int warp_m; | ||||
int warp_n; | int warp_n; | ||||
int warp_k; | int warp_k; | ||||
int stage; | |||||
}; | }; | ||||
AlgoInt8NCHW32IMMAImplicitGemm(AlgoParam algo_param) | AlgoInt8NCHW32IMMAImplicitGemm(AlgoParam algo_param) | ||||
: m_algo_param{algo_param} { | : m_algo_param{algo_param} { | ||||
@@ -770,6 +771,7 @@ public: | |||||
int warp_m; | int warp_m; | ||||
int warp_n; | int warp_n; | ||||
int warp_k; | int warp_k; | ||||
int stage; | |||||
}; | }; | ||||
AlgoInt4NCHW64IMMAImplicitGemmBase(AlgoParam algo_param) | AlgoInt4NCHW64IMMAImplicitGemmBase(AlgoParam algo_param) | ||||
@@ -897,6 +899,7 @@ public: | |||||
int warp_m; | int warp_m; | ||||
int warp_n; | int warp_n; | ||||
int warp_k; | int warp_k; | ||||
int stage; | |||||
int access_size; | int access_size; | ||||
}; | }; | ||||
@@ -38,7 +38,7 @@ void do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32( | |||||
const convolution::ConvParam& param, uint32_t nonlinear_mode, | const convolution::ConvParam& param, uint32_t nonlinear_mode, | ||||
float alpha, float beta, float gamma, float scale, | float alpha, float beta, float gamma, float scale, | ||||
const GemmCoord& threadblock_shape, const GemmCoord& warp_shape, | const GemmCoord& threadblock_shape, const GemmCoord& warp_shape, | ||||
cudaStream_t stream); | |||||
int stages, cudaStream_t stream); | |||||
template <bool NeedLoadFromConstMem> | template <bool NeedLoadFromConstMem> | ||||
void do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4( | void do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4( | ||||
@@ -47,7 +47,7 @@ void do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4( | |||||
const convolution::ConvParam& param, uint32_t nonlinear_mode, | const convolution::ConvParam& param, uint32_t nonlinear_mode, | ||||
float alpha, float beta, float gamma, float scale, | float alpha, float beta, float gamma, float scale, | ||||
const GemmCoord& threadblock_shape, const GemmCoord& warp_shape, | const GemmCoord& threadblock_shape, const GemmCoord& warp_shape, | ||||
cudaStream_t stream); | |||||
int stages, cudaStream_t stream); | |||||
template <bool NeedLoadFromConstMem> | template <bool NeedLoadFromConstMem> | ||||
void do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4( | void do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4( | ||||
@@ -83,7 +83,7 @@ void do_conv_bias_int4_int4_implicit_gemm_imma_ncdiv64hw64( | |||||
const convolution::ConvParam& param, uint32_t nonlinear_mode, | const convolution::ConvParam& param, uint32_t nonlinear_mode, | ||||
float alpha, float beta, float gamma, float scale, | float alpha, float beta, float gamma, float scale, | ||||
const GemmCoord& threadblock_shape, const GemmCoord& warp_shape, | const GemmCoord& threadblock_shape, const GemmCoord& warp_shape, | ||||
cudaStream_t stream); | |||||
int stages, cudaStream_t stream); | |||||
template <bool NeedLoadFromConstMem> | template <bool NeedLoadFromConstMem> | ||||
void do_conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64( | void do_conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64( | ||||
@@ -92,7 +92,7 @@ void do_conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64( | |||||
const convolution::ConvParam& param, uint32_t nonlinear_mode, | const convolution::ConvParam& param, uint32_t nonlinear_mode, | ||||
float alpha, float beta, float gamma, float delta, float theta, | float alpha, float beta, float gamma, float delta, float theta, | ||||
float scale, uint8_t src_zero_point, const GemmCoord& threadblock_shape, | float scale, uint8_t src_zero_point, const GemmCoord& threadblock_shape, | ||||
const GemmCoord& warp_shape, cudaStream_t stream); | |||||
const GemmCoord& warp_shape, int stages, cudaStream_t stream); | |||||
template <bool signedness> | template <bool signedness> | ||||
void do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nhwc( | void do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nhwc( | ||||
@@ -110,7 +110,7 @@ void do_conv_bias_int4_int4_implicit_gemm_imma_nhwc( | |||||
const convolution::ConvParam& param, uint32_t nonlinear_mode, | const convolution::ConvParam& param, uint32_t nonlinear_mode, | ||||
float alpha, float beta, float gamma, float scale, | float alpha, float beta, float gamma, float scale, | ||||
const GemmCoord& threadblock_shape, const GemmCoord& warp_shape, | const GemmCoord& threadblock_shape, const GemmCoord& warp_shape, | ||||
const int32_t access_size, cudaStream_t stream); | |||||
const int32_t access_size, int stages, cudaStream_t stream); | |||||
template <bool NeedLoadFromConstMem> | template <bool NeedLoadFromConstMem> | ||||
void do_conv_bias_uint4_int4_implicit_gemm_imma_nhwc( | void do_conv_bias_uint4_int4_implicit_gemm_imma_nhwc( | ||||
@@ -119,7 +119,7 @@ void do_conv_bias_uint4_int4_implicit_gemm_imma_nhwc( | |||||
const convolution::ConvParam& param, uint32_t nonlinear_mode, | const convolution::ConvParam& param, uint32_t nonlinear_mode, | ||||
float alpha, float beta, float gamma, float delta, float theta, | float alpha, float beta, float gamma, float delta, float theta, | ||||
float scale, uint8_t src_zero_point, const GemmCoord& threadblock_shape, | float scale, uint8_t src_zero_point, const GemmCoord& threadblock_shape, | ||||
const GemmCoord& warp_shape, const int32_t access_size, | |||||
const GemmCoord& warp_shape, const int32_t access_size, int stages, | |||||
cudaStream_t stream); | cudaStream_t stream); | ||||
} // namespace cutlass_wrapper | } // namespace cutlass_wrapper | ||||
@@ -0,0 +1,595 @@ | |||||
/** | |||||
* \file dnn/src/cuda/conv_bias/cutlass_convolution_wrapper.cu | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
// ignore warning of cutlass | |||||
#pragma GCC diagnostic push | |||||
#pragma GCC diagnostic ignored "-Wunused-parameter" | |||||
#pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||||
#if !MEGDNN_TEGRA_X1 | |||||
#include "cutlass/convolution/device/convolution.h" | |||||
#endif | |||||
#include "src/common/opr_param_defs_enumv.cuh" | |||||
#include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh" | |||||
#pragma GCC diagnostic pop | |||||
using namespace megdnn; | |||||
using namespace cuda; | |||||
using namespace cutlass_wrapper; | |||||
/* ====== cutlass kernel wrapper for int4 x int4 nchw64 layout ====== */ | |||||
#if MEGDNN_TEGRA_X1 | |||||
template <bool NeedLoadFromConstMem> | |||||
void megdnn::cuda::cutlass_wrapper:: | |||||
do_conv_bias_int4_int4_implicit_gemm_imma_ncdiv64hw64( | |||||
const int8_t* /* d_src */, const int8_t* /* d_filter */, | |||||
const int32_t* /* d_bias */, const int8_t* /* d_z */, | |||||
int8_t* /* d_dst */, int* /* workspace */, | |||||
const convolution::ConvParam& /* param */, | |||||
uint32_t /* nonlinear_mode */, float /* alpha */, | |||||
float /* beta */, float /* gamma */, float /* scale */, | |||||
const GemmCoord& /* threadblock_shape */, | |||||
const GemmCoord& /* warp_shape */, int /* stages */, | |||||
cudaStream_t /* stream */) {} | |||||
#else | |||||
template <bool NeedLoadFromConstMem> | |||||
void megdnn::cuda::cutlass_wrapper:: | |||||
do_conv_bias_int4_int4_implicit_gemm_imma_ncdiv64hw64( | |||||
const int8_t* d_src, const int8_t* d_filter, | |||||
const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, | |||||
int* workspace, const convolution::ConvParam& param, | |||||
uint32_t nonlinear_mode, float alpha, float beta, float gamma, | |||||
float scale, const GemmCoord& threadblock_shape, | |||||
const GemmCoord& warp_shape, int stages, cudaStream_t stream) { | |||||
#define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \ | |||||
threadblock_k_, warp_m_, warp_n_, \ | |||||
warp_k_, stage_) \ | |||||
if (threadblock_shape.m() == threadblock_m_ && \ | |||||
threadblock_shape.n() == threadblock_n_ && \ | |||||
threadblock_shape.k() == threadblock_k_ && \ | |||||
warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ | |||||
warp_shape.k() == warp_k_ && stages == stage_) { \ | |||||
using ThreadBlockShape = \ | |||||
cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \ | |||||
threadblock_k_>; \ | |||||
using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \ | |||||
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; \ | |||||
using Convolution = cutlass::conv::device::Convolution< \ | |||||
cutlass::int4b_t, cutlass::layout::TensorNCxHWx<64>, \ | |||||
cutlass::int4b_t, cutlass::layout::TensorCxRSKx<64>, \ | |||||
ElementOutput, cutlass::layout::TensorNCxHWx<64>, int32_t, \ | |||||
cutlass::layout::TensorNCxHWx<64>, int32_t, \ | |||||
cutlass::conv::ConvType::kConvolution, \ | |||||
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \ | |||||
ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ | |||||
cutlass::conv::threadblock:: \ | |||||
ConvolutionFpropTransThreadblockSwizzle, \ | |||||
stage_, 32, 32, NeedLoadFromConstMem, \ | |||||
cutlass::arch::OpMultiplyAddSaturate, \ | |||||
cutlass::conv::ImplicitGemmMode::GEMM_TN, true>; \ | |||||
typename Convolution::ConvolutionParameter conv_param( \ | |||||
param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ | |||||
param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ | |||||
param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \ | |||||
return cutlass_convolution_wrapper<Convolution>( \ | |||||
reinterpret_cast<const cutlass::int4b_t*>(d_src), \ | |||||
reinterpret_cast<const cutlass::int4b_t*>(d_filter), d_bias, \ | |||||
reinterpret_cast<const cutlass::int4b_t*>(d_z), \ | |||||
reinterpret_cast<cutlass::int4b_t*>(d_dst), workspace, \ | |||||
conv_param, epilogue, stream); \ | |||||
} | |||||
#define DISPATCH_KERNEL \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 128, 64, 64, 128, 2); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 256, 128, 64, 64, 128, 2); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 128, 64, 64, 128, 2); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1); \ | |||||
megdnn_assert(false, \ | |||||
"unsupported threadblock shape (%dx%dx%d) and warp shape " \ | |||||
"(%dx%dx%d)", \ | |||||
threadblock_shape.m(), threadblock_shape.n(), \ | |||||
threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ | |||||
warp_shape.k()); | |||||
using ElementOutput = cutlass::int4b_t; | |||||
using ElementAccumulator = int32_t; | |||||
using ElementBias = int32_t; | |||||
using ElementCompute = float; | |||||
using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode; | |||||
switch (nonlinear_mode) { | |||||
case NonlineMode::IDENTITY: { | |||||
using EpilogueOp = | |||||
cutlass::epilogue::thread::BiasAddLinearCombinationClamp< | |||||
ElementOutput, 16, ElementAccumulator, ElementBias, | |||||
ElementCompute>; | |||||
typename EpilogueOp::Params epilogue{alpha, beta, gamma}; | |||||
DISPATCH_KERNEL; | |||||
} | |||||
case NonlineMode::RELU: { | |||||
using EpilogueOp = cutlass::epilogue::thread:: | |||||
BiasAddLinearCombinationReluClamp< | |||||
ElementOutput, 16, ElementAccumulator, ElementBias, | |||||
ElementCompute>; | |||||
typename EpilogueOp::Params epilogue{alpha, beta, gamma, 0}; | |||||
DISPATCH_KERNEL; | |||||
} | |||||
case NonlineMode::H_SWISH: { | |||||
using EpilogueOp = cutlass::epilogue::thread:: | |||||
BiasAddLinearCombinationHSwishClamp< | |||||
ElementOutput, 16, ElementAccumulator, ElementBias, | |||||
ElementCompute>; | |||||
typename EpilogueOp::Params epilogue{alpha, beta, gamma, scale}; | |||||
DISPATCH_KERNEL; | |||||
} | |||||
default: | |||||
megdnn_assert(false, | |||||
"unsupported nonlinear mode for conv bias operator"); | |||||
} | |||||
#undef DISPATCH_KERNEL_WITH_TILE_SHAPE | |||||
#undef DISPATCH_KERNEL | |||||
} | |||||
#endif | |||||
#define INST(need_load_from_const_mem) \ | |||||
template void megdnn::cuda::cutlass_wrapper:: \ | |||||
do_conv_bias_int4_int4_implicit_gemm_imma_ncdiv64hw64< \ | |||||
need_load_from_const_mem>( \ | |||||
const int8_t* d_src, const int8_t* d_filter, \ | |||||
const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, \ | |||||
int* workspace, const convolution::ConvParam& param, \ | |||||
uint32_t nonlinear_mode, float alpha, float beta, \ | |||||
float gamma, float scale, \ | |||||
const GemmCoord& threadblock_shape, \ | |||||
const GemmCoord& warp_shape, int stages, \ | |||||
cudaStream_t stream); | |||||
INST(true); | |||||
#undef INST | |||||
/* ====== cutlass kernel wrapper for uint4 x int4 nchw64 layout ====== */ | |||||
#if MEGDNN_TEGRA_X1 | |||||
template <bool NeedLoadFromConstMem> | |||||
void megdnn::cuda::cutlass_wrapper:: | |||||
do_conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64( | |||||
const uint8_t* /* d_src */, const int8_t* /* d_filter */, | |||||
const int32_t* /* d_bias */, const uint8_t* /* d_z */, | |||||
uint8_t* /* d_dst */, int* /* workspace */, | |||||
const convolution::ConvParam& /* param */, | |||||
uint32_t /* nonlinear_mode */, float /* alpha */, | |||||
float /* beta */, float /* gamma */, float /* delta */, | |||||
float /* theta */, float /* scale */, | |||||
uint8_t /* src_zero_point */, | |||||
const GemmCoord& /* threadblock_shape */, | |||||
const GemmCoord& /* warp_shape */, int /* stages */, | |||||
cudaStream_t /* stream */) {} | |||||
#else | |||||
template <bool NeedLoadFromConstMem> | |||||
void megdnn::cuda::cutlass_wrapper:: | |||||
do_conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64( | |||||
const uint8_t* d_src, const int8_t* d_filter, | |||||
const int32_t* d_bias, const uint8_t* d_z, uint8_t* d_dst, | |||||
int* workspace, const convolution::ConvParam& param, | |||||
uint32_t nonlinear_mode, float alpha, float beta, float gamma, | |||||
float delta, float theta, float /* scale */, | |||||
uint8_t src_zero_point, const GemmCoord& threadblock_shape, | |||||
const GemmCoord& warp_shape, int stages, cudaStream_t stream) { | |||||
#define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \ | |||||
threadblock_k_, warp_m_, warp_n_, \ | |||||
warp_k_, stage_) \ | |||||
if (threadblock_shape.m() == threadblock_m_ && \ | |||||
threadblock_shape.n() == threadblock_n_ && \ | |||||
threadblock_shape.k() == threadblock_k_ && \ | |||||
warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ | |||||
warp_shape.k() == warp_k_ && stages == stage_) { \ | |||||
using ThreadBlockShape = \ | |||||
cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \ | |||||
threadblock_k_>; \ | |||||
using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \ | |||||
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; \ | |||||
using Convolution = cutlass::conv::device::Convolution< \ | |||||
cutlass::uint4b_t, cutlass::layout::TensorNCxHWx<64>, \ | |||||
cutlass::int4b_t, cutlass::layout::TensorCxRSKx<64>, \ | |||||
ElementOutput, cutlass::layout::TensorNCxHWx<64>, int32_t, \ | |||||
cutlass::layout::TensorNCxHWx<64>, int32_t, \ | |||||
cutlass::conv::ConvType::kConvolution, \ | |||||
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \ | |||||
ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ | |||||
cutlass::conv::threadblock:: \ | |||||
ConvolutionFpropTransThreadblockSwizzle, \ | |||||
stage_, 32, 32, NeedLoadFromConstMem, \ | |||||
cutlass::arch::OpMultiplyAddSaturate, \ | |||||
cutlass::conv::ImplicitGemmMode::GEMM_TN, true>; \ | |||||
typename Convolution::ConvolutionParameter conv_param( \ | |||||
param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ | |||||
param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ | |||||
param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \ | |||||
return cutlass_convolution_wrapper<Convolution>( \ | |||||
reinterpret_cast<const cutlass::uint4b_t*>(d_src), \ | |||||
reinterpret_cast<const cutlass::int4b_t*>(d_filter), d_bias, \ | |||||
reinterpret_cast<const cutlass::uint4b_t*>(d_z), \ | |||||
reinterpret_cast<cutlass::uint4b_t*>(d_dst), workspace, \ | |||||
conv_param, epilogue, stream, {src_zero_point}); \ | |||||
} | |||||
#define DISPATCH_KERNEL \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 128, 64, 64, 128, 2); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 256, 128, 64, 64, 128, 2); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 128, 64, 64, 128, 2); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1); \ | |||||
megdnn_assert(false, \ | |||||
"unsupported threadblock shape (%dx%dx%d) and warp shape " \ | |||||
"(%dx%dx%d)", \ | |||||
threadblock_shape.m(), threadblock_shape.n(), \ | |||||
threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ | |||||
warp_shape.k()); | |||||
using ElementOutput = cutlass::uint4b_t; | |||||
using ElementAccumulator = int32_t; | |||||
using ElementBias = int32_t; | |||||
using ElementCompute = float; | |||||
using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode; | |||||
switch (nonlinear_mode) { | |||||
case NonlineMode::IDENTITY: { | |||||
using EpilogueOp = | |||||
cutlass::epilogue::thread::BiasAddLinearCombinationClamp< | |||||
ElementOutput, 16, ElementAccumulator, ElementBias, | |||||
ElementCompute>; | |||||
typename EpilogueOp::Params epilogue{alpha, beta, gamma, | |||||
delta + theta}; | |||||
DISPATCH_KERNEL; | |||||
} | |||||
case NonlineMode::RELU: { | |||||
using EpilogueOp = cutlass::epilogue::thread:: | |||||
BiasAddLinearCombinationReluClamp< | |||||
ElementOutput, 16, ElementAccumulator, ElementBias, | |||||
ElementCompute>; | |||||
typename EpilogueOp::Params epilogue{alpha, beta, gamma, | |||||
0, delta, theta}; | |||||
DISPATCH_KERNEL; | |||||
} | |||||
default: | |||||
megdnn_assert(false, | |||||
"unsupported nonlinear mode for conv bias operator"); | |||||
} | |||||
#undef DISPATCH_KERNEL_WITH_TILE_SHAPE | |||||
#undef DISPATCH_KERNEL | |||||
} | |||||
#endif | |||||
#define INST(need_load_from_const_mem) \ | |||||
template void megdnn::cuda::cutlass_wrapper:: \ | |||||
do_conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64< \ | |||||
need_load_from_const_mem>( \ | |||||
const uint8_t* d_src, const int8_t* d_filter, \ | |||||
const int32_t* d_bias, const uint8_t* d_z, uint8_t* d_dst, \ | |||||
int* workspace, const convolution::ConvParam& param, \ | |||||
uint32_t nonlinear_mode, float alpha, float beta, \ | |||||
float gamma, float delta, float theta, float scale, \ | |||||
uint8_t src_zero_point, \ | |||||
const GemmCoord& threadblock_shape, \ | |||||
const GemmCoord& warp_shape, int stages, \ | |||||
cudaStream_t stream); | |||||
INST(true); | |||||
#undef INST | |||||
/* ====== cutlass kernel wrapper for int4 x int4 nhwc layout ====== */ | |||||
#if MEGDNN_TEGRA_X1 | |||||
template <bool NeedLoadFromConstMem> | |||||
void megdnn::cuda::cutlass_wrapper:: | |||||
do_conv_bias_int4_int4_implicit_gemm_imma_nhwc( | |||||
const int8_t* /* d_src */, const int8_t* /* d_filter */, | |||||
const int32_t* /* d_bias */, const int8_t* /* d_z */, | |||||
int8_t* /* d_dst */, int* /* workspace */, | |||||
const convolution::ConvParam& /* param */, | |||||
uint32_t /* nonlinear_mode */, float /* alpha */, | |||||
float /* beta */, float /* gamma */, float /* scale */, | |||||
const GemmCoord& /* threadblock_shape */, | |||||
const GemmCoord& /* warp_shape */, | |||||
const int32_t /* access_size */, int /* stages */, | |||||
cudaStream_t /* stream */) {} | |||||
#else | |||||
template <bool NeedLoadFromConstMem> | |||||
void megdnn::cuda::cutlass_wrapper:: | |||||
do_conv_bias_int4_int4_implicit_gemm_imma_nhwc( | |||||
const int8_t* d_src, const int8_t* d_filter, | |||||
const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, | |||||
int* workspace, const convolution::ConvParam& param, | |||||
uint32_t nonlinear_mode, float alpha, float beta, float gamma, | |||||
float scale, const GemmCoord& threadblock_shape, | |||||
const GemmCoord& warp_shape, const int32_t access_size, | |||||
int stages, cudaStream_t stream) { | |||||
bool without_shared_load = | |||||
((param.co % threadblock_shape.n() == 0) && | |||||
(threadblock_shape.n() == 32 || threadblock_shape.n() == 64)); | |||||
int out_elements_per_access = | |||||
without_shared_load ? threadblock_shape.n() / 4 : 8; | |||||
#define RUN_CUTLASS_WRAPPER(stage_, access_size_, without_shared_load_) \ | |||||
using Convolution = cutlass::conv::device::Convolution< \ | |||||
cutlass::int4b_t, cutlass::layout::TensorNHWC, cutlass::int4b_t, \ | |||||
cutlass::layout::TensorNCxHWx<access_size_>, ElementOutput, \ | |||||
cutlass::layout::TensorNHWC, int32_t, cutlass::layout::TensorNHWC, \ | |||||
int32_t, cutlass::conv::ConvType::kConvolution, \ | |||||
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \ | |||||
ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ | |||||
cutlass::conv::threadblock:: \ | |||||
ConvolutionFpropTransThreadblockSwizzle, \ | |||||
stage_, access_size_, access_size_, NeedLoadFromConstMem, \ | |||||
cutlass::arch::OpMultiplyAddSaturate, \ | |||||
cutlass::conv::ImplicitGemmMode::GEMM_TN, without_shared_load_>; \ | |||||
typename Convolution::ConvolutionParameter conv_param( \ | |||||
param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ | |||||
param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ | |||||
param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \ | |||||
return cutlass_convolution_wrapper<Convolution>( \ | |||||
reinterpret_cast<const cutlass::int4b_t*>(d_src), \ | |||||
reinterpret_cast<const cutlass::int4b_t*>(d_filter), d_bias, \ | |||||
reinterpret_cast<const cutlass::int4b_t*>(d_z), \ | |||||
reinterpret_cast<cutlass::int4b_t*>(d_dst), workspace, conv_param, \ | |||||
epilogue, stream); | |||||
#define DISPATCH_KERNEL_WITH_TILE_SHAPE( \ | |||||
threadblock_m_, threadblock_n_, threadblock_k_, warp_m_, warp_n_, \ | |||||
warp_k_, stage_, access_size_, out_elements_per_access_, \ | |||||
without_shared_load_) \ | |||||
if (threadblock_shape.m() == threadblock_m_ && \ | |||||
threadblock_shape.n() == threadblock_n_ && \ | |||||
threadblock_shape.k() == threadblock_k_ && \ | |||||
warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ | |||||
warp_shape.k() == warp_k_ && stages == stage_ && \ | |||||
access_size == access_size_ && \ | |||||
out_elements_per_access == out_elements_per_access_ && \ | |||||
without_shared_load == without_shared_load_) { \ | |||||
using ThreadBlockShape = \ | |||||
cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \ | |||||
threadblock_k_>; \ | |||||
using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \ | |||||
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; \ | |||||
using ElementOutput = cutlass::int4b_t; \ | |||||
using ElementAccumulator = int32_t; \ | |||||
using ElementBias = int32_t; \ | |||||
using ElementCompute = float; \ | |||||
using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode; \ | |||||
switch (nonlinear_mode) { \ | |||||
case NonlineMode::IDENTITY: { \ | |||||
using EpilogueOp = cutlass::epilogue::thread:: \ | |||||
BiasAddLinearCombinationClamp< \ | |||||
ElementOutput, out_elements_per_access_, \ | |||||
ElementAccumulator, ElementBias, \ | |||||
ElementCompute>; \ | |||||
typename EpilogueOp::Params epilogue{alpha, beta, gamma}; \ | |||||
RUN_CUTLASS_WRAPPER(stage_, access_size_, \ | |||||
without_shared_load_); \ | |||||
} \ | |||||
case NonlineMode::RELU: { \ | |||||
using EpilogueOp = cutlass::epilogue::thread:: \ | |||||
BiasAddLinearCombinationReluClamp< \ | |||||
ElementOutput, out_elements_per_access_, \ | |||||
ElementAccumulator, ElementBias, \ | |||||
ElementCompute>; \ | |||||
typename EpilogueOp::Params epilogue{alpha, beta, gamma, 0}; \ | |||||
RUN_CUTLASS_WRAPPER(stage_, access_size_, \ | |||||
without_shared_load_); \ | |||||
} \ | |||||
case NonlineMode::H_SWISH: { \ | |||||
using EpilogueOp = cutlass::epilogue::thread:: \ | |||||
BiasAddLinearCombinationHSwishClamp< \ | |||||
ElementOutput, out_elements_per_access_, \ | |||||
ElementAccumulator, ElementBias, \ | |||||
ElementCompute>; \ | |||||
typename EpilogueOp::Params epilogue{alpha, beta, gamma, \ | |||||
scale}; \ | |||||
RUN_CUTLASS_WRAPPER(stage_, access_size_, \ | |||||
without_shared_load_); \ | |||||
} \ | |||||
default: \ | |||||
megdnn_assert( \ | |||||
false, \ | |||||
"unsupported nonlinear mode for conv bias operator"); \ | |||||
} \ | |||||
} | |||||
#define DISPATCH_KERNEL \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 32, 8, false); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 16, 8, false); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 8, 8, false); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 32, 8, false); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 16, 8, false); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 8, 8, false); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 32, 8, true); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 16, 8, true); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 8, 8, true); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 32, 16, true); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 16, 16, true); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 8, 16, true); \ | |||||
megdnn_assert(false, \ | |||||
"unsupported threadblock shape (%dx%dx%d) and warp shape " \ | |||||
"(%dx%dx%d) and access_size (%d)", \ | |||||
threadblock_shape.m(), threadblock_shape.n(), \ | |||||
threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ | |||||
warp_shape.k(), access_size); | |||||
DISPATCH_KERNEL; | |||||
#undef RUN_CUTLASS_WRAPPER | |||||
#undef DISPATCH_KERNEL_WITH_TILE_SHAPE | |||||
#undef DISPATCH_KERNEL | |||||
} | |||||
#endif | |||||
#define INST(need_load_from_const_mem) \ | |||||
template void megdnn::cuda::cutlass_wrapper:: \ | |||||
do_conv_bias_int4_int4_implicit_gemm_imma_nhwc< \ | |||||
need_load_from_const_mem>( \ | |||||
const int8_t* d_src, const int8_t* d_filter, \ | |||||
const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, \ | |||||
int* workspace, const convolution::ConvParam& param, \ | |||||
uint32_t nonlinear_mode, float alpha, float beta, \ | |||||
float gamma, float scale, \ | |||||
const GemmCoord& threadblock_shape, \ | |||||
const GemmCoord& warp_shape, const int32_t access_size, \ | |||||
int stages, cudaStream_t stream); | |||||
INST(true); | |||||
INST(false); | |||||
#undef INST | |||||
/* ====== cutlass kernel wrapper for uint4 x int4 nhwc layout ====== */ | |||||
#if MEGDNN_TEGRA_X1 | |||||
template <bool NeedLoadFromConstMem> | |||||
void megdnn::cuda::cutlass_wrapper:: | |||||
do_conv_bias_uint4_int4_implicit_gemm_imma_nhwc( | |||||
const uint8_t* /* d_src */, const int8_t* /* d_filter */, | |||||
const int32_t* /* d_bias */, const uint8_t* /* d_z */, | |||||
uint8_t* /* d_dst */, int* /* workspace */, | |||||
const convolution::ConvParam& /* param */, | |||||
uint32_t /* nonlinear_mode */, float /* alpha */, | |||||
float /* beta */, float /* gamma */, float /* delta */, | |||||
float /* theta */, float /* scale */, | |||||
uint8_t /* src_zero_point */, | |||||
const GemmCoord& /* threadblock_shape */, | |||||
const GemmCoord& /* warp_shape */, | |||||
const int32_t /* access_size */, int /* stages */, | |||||
cudaStream_t /* stream */) {} | |||||
#else | |||||
template <bool NeedLoadFromConstMem> | |||||
void megdnn::cuda::cutlass_wrapper:: | |||||
do_conv_bias_uint4_int4_implicit_gemm_imma_nhwc( | |||||
const uint8_t* d_src, const int8_t* d_filter, | |||||
const int32_t* d_bias, const uint8_t* d_z, uint8_t* d_dst, | |||||
int* workspace, const convolution::ConvParam& param, | |||||
uint32_t nonlinear_mode, float alpha, float beta, float gamma, | |||||
float delta, float theta, float /* scale */, | |||||
uint8_t src_zero_point, const GemmCoord& threadblock_shape, | |||||
const GemmCoord& warp_shape, const int32_t access_size, | |||||
int stages, cudaStream_t stream) { | |||||
bool without_shared_load = | |||||
((param.co % threadblock_shape.n() == 0) && | |||||
(threadblock_shape.n() == 32 || threadblock_shape.n() == 64)); | |||||
int out_elements_per_access = | |||||
without_shared_load ? threadblock_shape.n() / 4 : 8; | |||||
#define RUN_CUTLASS_WRAPPER(stage_, access_size_, without_shared_load_) \ | |||||
using Convolution = cutlass::conv::device::Convolution< \ | |||||
cutlass::uint4b_t, cutlass::layout::TensorNHWC, cutlass::int4b_t, \ | |||||
cutlass::layout::TensorNCxHWx<access_size_>, ElementOutput, \ | |||||
cutlass::layout::TensorNHWC, int32_t, cutlass::layout::TensorNHWC, \ | |||||
int32_t, cutlass::conv::ConvType::kConvolution, \ | |||||
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \ | |||||
ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ | |||||
cutlass::conv::threadblock:: \ | |||||
ConvolutionFpropTransThreadblockSwizzle, \ | |||||
stage_, access_size_, access_size_, NeedLoadFromConstMem, \ | |||||
cutlass::arch::OpMultiplyAddSaturate, \ | |||||
cutlass::conv::ImplicitGemmMode::GEMM_TN, without_shared_load_>; \ | |||||
typename Convolution::ConvolutionParameter conv_param( \ | |||||
param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ | |||||
param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ | |||||
param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \ | |||||
return cutlass_convolution_wrapper<Convolution>( \ | |||||
reinterpret_cast<const cutlass::uint4b_t*>(d_src), \ | |||||
reinterpret_cast<const cutlass::int4b_t*>(d_filter), d_bias, \ | |||||
reinterpret_cast<const cutlass::uint4b_t*>(d_z), \ | |||||
reinterpret_cast<cutlass::uint4b_t*>(d_dst), workspace, \ | |||||
conv_param, epilogue, stream, {src_zero_point}); | |||||
#define DISPATCH_KERNEL_WITH_TILE_SHAPE( \ | |||||
threadblock_m_, threadblock_n_, threadblock_k_, warp_m_, warp_n_, \ | |||||
warp_k_, stage_, access_size_, out_elements_per_access_, \ | |||||
without_shared_load_) \ | |||||
if (threadblock_shape.m() == threadblock_m_ && \ | |||||
threadblock_shape.n() == threadblock_n_ && \ | |||||
threadblock_shape.k() == threadblock_k_ && \ | |||||
warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ | |||||
warp_shape.k() == warp_k_ && stages == stage_ && \ | |||||
access_size == access_size_ && \ | |||||
out_elements_per_access == out_elements_per_access_ && \ | |||||
without_shared_load == without_shared_load_) { \ | |||||
using ThreadBlockShape = \ | |||||
cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \ | |||||
threadblock_k_>; \ | |||||
using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \ | |||||
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; \ | |||||
using ElementOutput = cutlass::uint4b_t; \ | |||||
using ElementAccumulator = int32_t; \ | |||||
using ElementBias = int32_t; \ | |||||
using ElementCompute = float; \ | |||||
using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode; \ | |||||
switch (nonlinear_mode) { \ | |||||
case NonlineMode::IDENTITY: { \ | |||||
using EpilogueOp = cutlass::epilogue::thread:: \ | |||||
BiasAddLinearCombinationClamp< \ | |||||
ElementOutput, out_elements_per_access_, \ | |||||
ElementAccumulator, ElementBias, \ | |||||
ElementCompute>; \ | |||||
typename EpilogueOp::Params epilogue{alpha, beta, gamma, \ | |||||
delta + theta}; \ | |||||
RUN_CUTLASS_WRAPPER(stage_, access_size_, \ | |||||
without_shared_load_); \ | |||||
} \ | |||||
case NonlineMode::RELU: { \ | |||||
using EpilogueOp = cutlass::epilogue::thread:: \ | |||||
BiasAddLinearCombinationReluClamp< \ | |||||
ElementOutput, out_elements_per_access_, \ | |||||
ElementAccumulator, ElementBias, \ | |||||
ElementCompute>; \ | |||||
typename EpilogueOp::Params epilogue{alpha, beta, gamma, \ | |||||
0, delta, theta}; \ | |||||
RUN_CUTLASS_WRAPPER(stage_, access_size_, \ | |||||
without_shared_load_); \ | |||||
} \ | |||||
default: \ | |||||
megdnn_assert( \ | |||||
false, \ | |||||
"unsupported nonlinear mode for conv bias operator"); \ | |||||
} \ | |||||
} | |||||
#define DISPATCH_KERNEL \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 32, 8, false); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 16, 8, false); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 8, 8, false); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 32, 8, false); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 16, 8, false); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 8, 8, false); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 32, 8, true); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 16, 8, true); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 8, 8, true); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 32, 16, true); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 16, 16, true); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 8, 16, true); \ | |||||
megdnn_assert(false, \ | |||||
"unsupported threadblock shape (%dx%dx%d) and warp shape " \ | |||||
"(%dx%dx%d) and access_size (%d)", \ | |||||
threadblock_shape.m(), threadblock_shape.n(), \ | |||||
threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ | |||||
warp_shape.k(), access_size); | |||||
DISPATCH_KERNEL; | |||||
#undef RUN_CUTLASS_WRAPPER | |||||
#undef DISPATCH_KERNEL_WITH_TILE_SHAPE | |||||
#undef DISPATCH_KERNEL | |||||
} | |||||
#endif | |||||
#define INST(need_load_from_const_mem) \ | |||||
template void megdnn::cuda::cutlass_wrapper:: \ | |||||
do_conv_bias_uint4_int4_implicit_gemm_imma_nhwc< \ | |||||
need_load_from_const_mem>( \ | |||||
const uint8_t* d_src, const int8_t* d_filter, \ | |||||
const int32_t* d_bias, const uint8_t* d_z, uint8_t* d_dst, \ | |||||
int* workspace, const convolution::ConvParam& param, \ | |||||
uint32_t nonlinear_mode, float alpha, float beta, \ | |||||
float gamma, float delta, float theta, float scale, \ | |||||
uint8_t src_zero_point, \ | |||||
const GemmCoord& threadblock_shape, \ | |||||
const GemmCoord& warp_shape, const int32_t access_size, \ | |||||
int stages, cudaStream_t stream); | |||||
INST(true); | |||||
INST(false); | |||||
#undef INST | |||||
// vim: syntax=cuda.doxygen |
@@ -38,7 +38,8 @@ void megdnn::cuda::cutlass_wrapper:: | |||||
uint32_t /* nonlinear_mode */, float /* alpha */, | uint32_t /* nonlinear_mode */, float /* alpha */, | ||||
float /* beta */, float /* gamma */, float /* scale */, | float /* beta */, float /* gamma */, float /* scale */, | ||||
const GemmCoord& /* threadblock_shape */, | const GemmCoord& /* threadblock_shape */, | ||||
const GemmCoord& /* warp_shape */, cudaStream_t /* stream */) {} | |||||
const GemmCoord& /* warp_shape */, int /* stages */, | |||||
cudaStream_t /* stream */) {} | |||||
#else | #else | ||||
template <bool NeedLoadFromConstMem> | template <bool NeedLoadFromConstMem> | ||||
void megdnn::cuda::cutlass_wrapper:: | void megdnn::cuda::cutlass_wrapper:: | ||||
@@ -48,15 +49,15 @@ void megdnn::cuda::cutlass_wrapper:: | |||||
int* workspace, const convolution::ConvParam& param, | int* workspace, const convolution::ConvParam& param, | ||||
uint32_t nonlinear_mode, float alpha, float beta, float gamma, | uint32_t nonlinear_mode, float alpha, float beta, float gamma, | ||||
float scale, const GemmCoord& threadblock_shape, | float scale, const GemmCoord& threadblock_shape, | ||||
const GemmCoord& warp_shape, cudaStream_t stream) { | |||||
const GemmCoord& warp_shape, int stages, cudaStream_t stream) { | |||||
#define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \ | #define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \ | ||||
threadblock_k_, warp_m_, warp_n_, \ | threadblock_k_, warp_m_, warp_n_, \ | ||||
warp_k_) \ | |||||
warp_k_, stage_) \ | |||||
if (threadblock_shape.m() == threadblock_m_ && \ | if (threadblock_shape.m() == threadblock_m_ && \ | ||||
threadblock_shape.n() == threadblock_n_ && \ | threadblock_shape.n() == threadblock_n_ && \ | ||||
threadblock_shape.k() == threadblock_k_ && \ | threadblock_shape.k() == threadblock_k_ && \ | ||||
warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ | warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ | ||||
warp_shape.k() == warp_k_) { \ | |||||
warp_shape.k() == warp_k_ && stages == stage_) { \ | |||||
using ThreadBlockShape = \ | using ThreadBlockShape = \ | ||||
cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \ | cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \ | ||||
threadblock_k_>; \ | threadblock_k_>; \ | ||||
@@ -71,8 +72,10 @@ void megdnn::cuda::cutlass_wrapper:: | |||||
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \ | cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \ | ||||
ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ | ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ | ||||
cutlass::conv::threadblock:: \ | cutlass::conv::threadblock:: \ | ||||
ConvolutionFpropNCxHWxThreadblockSwizzle, \ | |||||
2, 16, 16, NeedLoadFromConstMem>; \ | |||||
ConvolutionFpropTransThreadblockSwizzle, \ | |||||
stage_, 16, 16, NeedLoadFromConstMem, \ | |||||
cutlass::arch::OpMultiplyAddSaturate, \ | |||||
cutlass::conv::ImplicitGemmMode::GEMM_TN, true>; \ | |||||
typename Convolution::ConvolutionParameter conv_param( \ | typename Convolution::ConvolutionParameter conv_param( \ | ||||
param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ | param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ | ||||
param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ | param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ | ||||
@@ -82,13 +85,15 @@ void megdnn::cuda::cutlass_wrapper:: | |||||
epilogue, stream); \ | epilogue, stream); \ | ||||
} | } | ||||
#define DISPATCH_KERNEL \ | #define DISPATCH_KERNEL \ | ||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(256, 128, 64, 64, 64, 64); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 256, 64, 64, 64, 64); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 64, 64, 64, 64); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 64, 32, 64, 64); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 32, 64); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 64, 64, 32, 32, 64); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 64, 32, 16, 64); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 256, 64, 64, 64, 64, 2); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(256, 128, 64, 64, 64, 64, 2); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 64, 64, 64, 64, 2); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 32, 64, 2); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 64, 32, 64, 64, 2); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 32, 64, 32, 32, 1); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 32, 64, 32, 32, 1); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 32, 64, 32, 1); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 1); \ | |||||
megdnn_assert(false, \ | megdnn_assert(false, \ | ||||
"unsupported threadblock shape (%dx%dx%d) and warp shape " \ | "unsupported threadblock shape (%dx%dx%d) and warp shape " \ | ||||
"(%dx%dx%d)", \ | "(%dx%dx%d)", \ | ||||
@@ -144,7 +149,8 @@ void megdnn::cuda::cutlass_wrapper:: | |||||
uint32_t nonlinear_mode, float alpha, float beta, \ | uint32_t nonlinear_mode, float alpha, float beta, \ | ||||
float gamma, float scale, \ | float gamma, float scale, \ | ||||
const GemmCoord& threadblock_shape, \ | const GemmCoord& threadblock_shape, \ | ||||
const GemmCoord& warp_shape, cudaStream_t stream); | |||||
const GemmCoord& warp_shape, int stages, \ | |||||
cudaStream_t stream); | |||||
INST(true); | INST(true); | ||||
INST(false); | INST(false); | ||||
#undef INST | #undef INST | ||||
@@ -162,7 +168,8 @@ void megdnn::cuda::cutlass_wrapper:: | |||||
uint32_t /* nonlinear_mode */, float /* alpha */, | uint32_t /* nonlinear_mode */, float /* alpha */, | ||||
float /* beta */, float /* gamma */, float /* scale */, | float /* beta */, float /* gamma */, float /* scale */, | ||||
const GemmCoord& /* threadblock_shape */, | const GemmCoord& /* threadblock_shape */, | ||||
const GemmCoord& /* warp_shape */, cudaStream_t /* stream */) {} | |||||
const GemmCoord& /* warp_shape */, int /* stages */, | |||||
cudaStream_t /* stream */) {} | |||||
#else | #else | ||||
template <bool NeedLoadFromConstMem> | template <bool NeedLoadFromConstMem> | ||||
void megdnn::cuda::cutlass_wrapper:: | void megdnn::cuda::cutlass_wrapper:: | ||||
@@ -172,15 +179,15 @@ void megdnn::cuda::cutlass_wrapper:: | |||||
int* workspace, const convolution::ConvParam& param, | int* workspace, const convolution::ConvParam& param, | ||||
uint32_t nonlinear_mode, float alpha, float beta, float gamma, | uint32_t nonlinear_mode, float alpha, float beta, float gamma, | ||||
float scale, const GemmCoord& threadblock_shape, | float scale, const GemmCoord& threadblock_shape, | ||||
const GemmCoord& warp_shape, cudaStream_t stream) { | |||||
const GemmCoord& warp_shape, int stages, cudaStream_t stream) { | |||||
#define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \ | #define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \ | ||||
threadblock_k_, warp_m_, warp_n_, \ | threadblock_k_, warp_m_, warp_n_, \ | ||||
warp_k_) \ | |||||
warp_k_, stage_) \ | |||||
if (threadblock_shape.m() == threadblock_m_ && \ | if (threadblock_shape.m() == threadblock_m_ && \ | ||||
threadblock_shape.n() == threadblock_n_ && \ | threadblock_shape.n() == threadblock_n_ && \ | ||||
threadblock_shape.k() == threadblock_k_ && \ | threadblock_shape.k() == threadblock_k_ && \ | ||||
warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ | warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ | ||||
warp_shape.k() == warp_k_) { \ | |||||
warp_shape.k() == warp_k_ && stages == stage_) { \ | |||||
using ThreadBlockShape = \ | using ThreadBlockShape = \ | ||||
cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \ | cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \ | ||||
threadblock_k_>; \ | threadblock_k_>; \ | ||||
@@ -196,7 +203,7 @@ void megdnn::cuda::cutlass_wrapper:: | |||||
ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ | ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ | ||||
cutlass::conv::threadblock:: \ | cutlass::conv::threadblock:: \ | ||||
ConvolutionFpropNCxHWxThreadblockSwizzle, \ | ConvolutionFpropNCxHWxThreadblockSwizzle, \ | ||||
2, 16, 16, NeedLoadFromConstMem>; \ | |||||
stage_, 16, 16, NeedLoadFromConstMem>; \ | |||||
typename Convolution::ConvolutionParameter conv_param( \ | typename Convolution::ConvolutionParameter conv_param( \ | ||||
param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ | param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ | ||||
param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ | param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ | ||||
@@ -206,13 +213,15 @@ void megdnn::cuda::cutlass_wrapper:: | |||||
epilogue, stream); \ | epilogue, stream); \ | ||||
} | } | ||||
#define DISPATCH_KERNEL \ | #define DISPATCH_KERNEL \ | ||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(256, 128, 64, 64, 64, 64); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 256, 64, 64, 64, 64); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 64, 64, 64, 64); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 64, 32, 64, 64); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 32, 64); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 64, 64, 32, 32, 64); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 64, 16, 32, 64); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 256, 64, 64, 64, 64, 2); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(256, 128, 64, 64, 64, 64, 2); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 64, 64, 64, 64, 2); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 32, 64, 2); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 64, 32, 64, 64, 2); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 32, 64, 32, 32, 1); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 32, 64, 32, 32, 1); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 32, 64, 32, 1); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 1); \ | |||||
megdnn_assert(false, \ | megdnn_assert(false, \ | ||||
"unsupported threadblock shape (%dx%dx%d) and warp shape " \ | "unsupported threadblock shape (%dx%dx%d) and warp shape " \ | ||||
"(%dx%dx%d)", \ | "(%dx%dx%d)", \ | ||||
@@ -268,7 +277,8 @@ void megdnn::cuda::cutlass_wrapper:: | |||||
uint32_t nonlinear_mode, float alpha, float beta, \ | uint32_t nonlinear_mode, float alpha, float beta, \ | ||||
float gamma, float scale, \ | float gamma, float scale, \ | ||||
const GemmCoord& threadblock_shape, \ | const GemmCoord& threadblock_shape, \ | ||||
const GemmCoord& warp_shape, cudaStream_t stream); | |||||
const GemmCoord& warp_shape, int stages, \ | |||||
cudaStream_t stream); | |||||
INST(true); | INST(true); | ||||
INST(false); | INST(false); | ||||
#undef INST | #undef INST | ||||
@@ -337,10 +347,8 @@ void megdnn::cuda::cutlass_wrapper:: | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 64, 32, 32, 2, 16); \ | DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 64, 32, 32, 2, 16); \ | ||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 32, 64, 32, 32, 2, 16); \ | DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 32, 64, 32, 32, 2, 16); \ | ||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 2, 16); \ | DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 2, 16); \ | ||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 64, 32, 64, 32, 32, 2, 16); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 32, 32, 64, 32, 2, 16); \ | DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 32, 32, 64, 32, 2, 16); \ | ||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 32, 32, 64, 32, 32, 2, 16); \ | DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 32, 32, 64, 32, 32, 2, 16); \ | ||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 32, 32, 32, 32, 32, 2, 16); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 128, 16, 16, 128, 16, 1, 8); \ | DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 128, 16, 16, 128, 16, 1, 8); \ | ||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 64, 8, 16, 64, 8, 2, 4); \ | DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 64, 8, 16, 64, 8, 2, 4); \ | ||||
megdnn_assert(false, \ | megdnn_assert(false, \ | ||||
@@ -468,10 +476,8 @@ void megdnn::cuda::cutlass_wrapper:: | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 64, 32, 32, 2, 16); \ | DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 64, 32, 32, 2, 16); \ | ||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 32, 64, 32, 32, 2, 16); \ | DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 32, 64, 32, 32, 2, 16); \ | ||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 2, 16); \ | DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 2, 16); \ | ||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 64, 32, 64, 32, 32, 2, 16); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 32, 32, 64, 32, 2, 16); \ | DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 32, 32, 64, 32, 2, 16); \ | ||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 32, 32, 64, 32, 32, 2, 16); \ | DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 32, 32, 64, 32, 32, 2, 16); \ | ||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 32, 32, 32, 32, 32, 2, 16); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 128, 16, 16, 128, 16, 1, 8); \ | DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 128, 16, 16, 128, 16, 1, 8); \ | ||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 64, 8, 16, 64, 8, 2, 4); \ | DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 64, 8, 16, 64, 8, 2, 4); \ | ||||
megdnn_assert(false, \ | megdnn_assert(false, \ | ||||
@@ -599,10 +605,8 @@ void megdnn::cuda::cutlass_wrapper:: | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 64, 32, 32, 2, 16); \ | DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 64, 32, 32, 2, 16); \ | ||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 32, 64, 32, 32, 2, 16); \ | DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 32, 64, 32, 32, 2, 16); \ | ||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 2, 16); \ | DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 2, 16); \ | ||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 64, 32, 64, 32, 32, 2, 16); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 32, 32, 64, 32, 2, 16); \ | DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 32, 32, 64, 32, 2, 16); \ | ||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 32, 32, 64, 32, 32, 2, 16); \ | DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 32, 32, 64, 32, 32, 2, 16); \ | ||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 32, 32, 32, 32, 32, 2, 16); \ | |||||
megdnn_assert(false, \ | megdnn_assert(false, \ | ||||
"unsupported threadblock shape (%dx%dx%d) and warp shape " \ | "unsupported threadblock shape (%dx%dx%d) and warp shape " \ | ||||
"(%dx%dx%d)", \ | "(%dx%dx%d)", \ | ||||
@@ -664,246 +668,6 @@ INST(true); | |||||
INST(false); | INST(false); | ||||
#undef INST | #undef INST | ||||
/* ====== cutlass kernel wrapper for int4 x int4 nchw64 layout ====== */ | |||||
#if MEGDNN_TEGRA_X1 | |||||
template <bool NeedLoadFromConstMem> | |||||
void megdnn::cuda::cutlass_wrapper:: | |||||
do_conv_bias_int4_int4_implicit_gemm_imma_ncdiv64hw64( | |||||
const int8_t* /* d_src */, const int8_t* /* d_filter */, | |||||
const int32_t* /* d_bias */, const int8_t* /* d_z */, | |||||
int8_t* /* d_dst */, int* /* workspace */, | |||||
const convolution::ConvParam& /* param */, | |||||
uint32_t /* nonlinear_mode */, float /* alpha */, | |||||
float /* beta */, float /* gamma */, float /* scale */, | |||||
const GemmCoord& /* threadblock_shape */, | |||||
const GemmCoord& /* warp_shape */, cudaStream_t /* stream */) {} | |||||
#else | |||||
template <bool NeedLoadFromConstMem> | |||||
void megdnn::cuda::cutlass_wrapper:: | |||||
do_conv_bias_int4_int4_implicit_gemm_imma_ncdiv64hw64( | |||||
const int8_t* d_src, const int8_t* d_filter, | |||||
const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, | |||||
int* workspace, const convolution::ConvParam& param, | |||||
uint32_t nonlinear_mode, float alpha, float beta, float gamma, | |||||
float scale, const GemmCoord& threadblock_shape, | |||||
const GemmCoord& warp_shape, cudaStream_t stream) { | |||||
#define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \ | |||||
threadblock_k_, warp_m_, warp_n_, \ | |||||
warp_k_) \ | |||||
if (threadblock_shape.m() == threadblock_m_ && \ | |||||
threadblock_shape.n() == threadblock_n_ && \ | |||||
threadblock_shape.k() == threadblock_k_ && \ | |||||
warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ | |||||
warp_shape.k() == warp_k_) { \ | |||||
using ThreadBlockShape = \ | |||||
cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \ | |||||
threadblock_k_>; \ | |||||
using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \ | |||||
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; \ | |||||
using Convolution = cutlass::conv::device::Convolution< \ | |||||
cutlass::int4b_t, cutlass::layout::TensorNCxHWx<64>, \ | |||||
cutlass::int4b_t, cutlass::layout::TensorCxRSKx<64>, \ | |||||
ElementOutput, cutlass::layout::TensorNCxHWx<64>, int32_t, \ | |||||
cutlass::layout::TensorNCxHWx<64>, int32_t, \ | |||||
cutlass::conv::ConvType::kConvolution, \ | |||||
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \ | |||||
ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ | |||||
cutlass::conv::threadblock:: \ | |||||
ConvolutionFpropNCxHWxThreadblockSwizzle, \ | |||||
2, 32, 32, NeedLoadFromConstMem>; \ | |||||
typename Convolution::ConvolutionParameter conv_param( \ | |||||
param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ | |||||
param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ | |||||
param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \ | |||||
return cutlass_convolution_wrapper<Convolution>( \ | |||||
reinterpret_cast<const cutlass::int4b_t*>(d_src), \ | |||||
reinterpret_cast<const cutlass::int4b_t*>(d_filter), d_bias, \ | |||||
reinterpret_cast<const cutlass::int4b_t*>(d_z), \ | |||||
reinterpret_cast<cutlass::int4b_t*>(d_dst), workspace, \ | |||||
conv_param, epilogue, stream); \ | |||||
} | |||||
#define DISPATCH_KERNEL \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 128, 64, 64, 128); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(256, 128, 128, 64, 64, 128); \ | |||||
megdnn_assert(false, \ | |||||
"unsupported threadblock shape (%dx%dx%d) and warp shape " \ | |||||
"(%dx%dx%d)", \ | |||||
threadblock_shape.m(), threadblock_shape.n(), \ | |||||
threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ | |||||
warp_shape.k()); | |||||
using ElementOutput = cutlass::int4b_t; | |||||
using ElementAccumulator = int32_t; | |||||
using ElementBias = int32_t; | |||||
using ElementCompute = float; | |||||
using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode; | |||||
switch (nonlinear_mode) { | |||||
case NonlineMode::IDENTITY: { | |||||
using EpilogueOp = | |||||
cutlass::epilogue::thread::BiasAddLinearCombinationClamp< | |||||
ElementOutput, 16, ElementAccumulator, ElementBias, | |||||
ElementCompute>; | |||||
typename EpilogueOp::Params epilogue{alpha, beta, gamma}; | |||||
DISPATCH_KERNEL; | |||||
} | |||||
case NonlineMode::RELU: { | |||||
using EpilogueOp = cutlass::epilogue::thread:: | |||||
BiasAddLinearCombinationReluClamp< | |||||
ElementOutput, 16, ElementAccumulator, ElementBias, | |||||
ElementCompute>; | |||||
typename EpilogueOp::Params epilogue{alpha, beta, gamma, 0}; | |||||
DISPATCH_KERNEL; | |||||
} | |||||
case NonlineMode::H_SWISH: { | |||||
using EpilogueOp = cutlass::epilogue::thread:: | |||||
BiasAddLinearCombinationHSwishClamp< | |||||
ElementOutput, 16, ElementAccumulator, ElementBias, | |||||
ElementCompute>; | |||||
typename EpilogueOp::Params epilogue{alpha, beta, gamma, scale}; | |||||
DISPATCH_KERNEL; | |||||
} | |||||
default: | |||||
megdnn_assert(false, | |||||
"unsupported nonlinear mode for conv bias operator"); | |||||
} | |||||
#undef DISPATCH_KERNEL_WITH_TILE_SHAPE | |||||
#undef DISPATCH_KERNEL | |||||
} | |||||
#endif | |||||
#define INST(need_load_from_const_mem) \ | |||||
template void megdnn::cuda::cutlass_wrapper:: \ | |||||
do_conv_bias_int4_int4_implicit_gemm_imma_ncdiv64hw64< \ | |||||
need_load_from_const_mem>( \ | |||||
const int8_t* d_src, const int8_t* d_filter, \ | |||||
const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, \ | |||||
int* workspace, const convolution::ConvParam& param, \ | |||||
uint32_t nonlinear_mode, float alpha, float beta, \ | |||||
float gamma, float scale, \ | |||||
const GemmCoord& threadblock_shape, \ | |||||
const GemmCoord& warp_shape, cudaStream_t stream); | |||||
INST(true); | |||||
#undef INST | |||||
/* ====== cutlass kernel wrapper for uint4 x int4 nchw64 layout ====== */ | |||||
#if MEGDNN_TEGRA_X1 | |||||
template <bool NeedLoadFromConstMem> | |||||
void megdnn::cuda::cutlass_wrapper:: | |||||
do_conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64( | |||||
const uint8_t* /* d_src */, const int8_t* /* d_filter */, | |||||
const int32_t* /* d_bias */, const uint8_t* /* d_z */, | |||||
uint8_t* /* d_dst */, int* /* workspace */, | |||||
const convolution::ConvParam& /* param */, | |||||
uint32_t /* nonlinear_mode */, float /* alpha */, | |||||
float /* beta */, float /* gamma */, float /* delta */, | |||||
float /* theta */, float /* scale */, | |||||
uint8_t /* src_zero_point */, | |||||
const GemmCoord& /* threadblock_shape */, | |||||
const GemmCoord& /* warp_shape */, cudaStream_t /* stream */) {} | |||||
#else | |||||
template <bool NeedLoadFromConstMem> | |||||
void megdnn::cuda::cutlass_wrapper:: | |||||
do_conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64( | |||||
const uint8_t* d_src, const int8_t* d_filter, | |||||
const int32_t* d_bias, const uint8_t* d_z, uint8_t* d_dst, | |||||
int* workspace, const convolution::ConvParam& param, | |||||
uint32_t nonlinear_mode, float alpha, float beta, float gamma, | |||||
float delta, float theta, float /* scale */, | |||||
uint8_t src_zero_point, const GemmCoord& threadblock_shape, | |||||
const GemmCoord& warp_shape, cudaStream_t stream) { | |||||
#define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \ | |||||
threadblock_k_, warp_m_, warp_n_, \ | |||||
warp_k_) \ | |||||
if (threadblock_shape.m() == threadblock_m_ && \ | |||||
threadblock_shape.n() == threadblock_n_ && \ | |||||
threadblock_shape.k() == threadblock_k_ && \ | |||||
warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ | |||||
warp_shape.k() == warp_k_) { \ | |||||
using ThreadBlockShape = \ | |||||
cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \ | |||||
threadblock_k_>; \ | |||||
using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \ | |||||
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; \ | |||||
using Convolution = cutlass::conv::device::Convolution< \ | |||||
cutlass::uint4b_t, cutlass::layout::TensorNCxHWx<64>, \ | |||||
cutlass::int4b_t, cutlass::layout::TensorCxRSKx<64>, \ | |||||
ElementOutput, cutlass::layout::TensorNCxHWx<64>, int32_t, \ | |||||
cutlass::layout::TensorNCxHWx<64>, int32_t, \ | |||||
cutlass::conv::ConvType::kConvolution, \ | |||||
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \ | |||||
ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ | |||||
cutlass::conv::threadblock:: \ | |||||
ConvolutionFpropNCxHWxThreadblockSwizzle, \ | |||||
2, 32, 32, NeedLoadFromConstMem>; \ | |||||
typename Convolution::ConvolutionParameter conv_param( \ | |||||
param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ | |||||
param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ | |||||
param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \ | |||||
return cutlass_convolution_wrapper<Convolution>( \ | |||||
reinterpret_cast<const cutlass::uint4b_t*>(d_src), \ | |||||
reinterpret_cast<const cutlass::int4b_t*>(d_filter), d_bias, \ | |||||
reinterpret_cast<const cutlass::uint4b_t*>(d_z), \ | |||||
reinterpret_cast<cutlass::uint4b_t*>(d_dst), workspace, \ | |||||
conv_param, epilogue, stream, {src_zero_point}); \ | |||||
} | |||||
#define DISPATCH_KERNEL \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 128, 64, 64, 128); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(256, 128, 128, 64, 64, 128); \ | |||||
megdnn_assert(false, \ | |||||
"unsupported threadblock shape (%dx%dx%d) and warp shape " \ | |||||
"(%dx%dx%d)", \ | |||||
threadblock_shape.m(), threadblock_shape.n(), \ | |||||
threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ | |||||
warp_shape.k()); | |||||
using ElementOutput = cutlass::uint4b_t; | |||||
using ElementAccumulator = int32_t; | |||||
using ElementBias = int32_t; | |||||
using ElementCompute = float; | |||||
using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode; | |||||
switch (nonlinear_mode) { | |||||
case NonlineMode::IDENTITY: { | |||||
using EpilogueOp = | |||||
cutlass::epilogue::thread::BiasAddLinearCombinationClamp< | |||||
ElementOutput, 16, ElementAccumulator, ElementBias, | |||||
ElementCompute>; | |||||
typename EpilogueOp::Params epilogue{alpha, beta, gamma, | |||||
delta + theta}; | |||||
DISPATCH_KERNEL; | |||||
} | |||||
case NonlineMode::RELU: { | |||||
using EpilogueOp = cutlass::epilogue::thread:: | |||||
BiasAddLinearCombinationReluClamp< | |||||
ElementOutput, 16, ElementAccumulator, ElementBias, | |||||
ElementCompute>; | |||||
typename EpilogueOp::Params epilogue{alpha, beta, gamma, | |||||
0, delta, theta}; | |||||
DISPATCH_KERNEL; | |||||
} | |||||
default: | |||||
megdnn_assert(false, | |||||
"unsupported nonlinear mode for conv bias operator"); | |||||
} | |||||
#undef DISPATCH_KERNEL_WITH_TILE_SHAPE | |||||
#undef DISPATCH_KERNEL | |||||
} | |||||
#endif | |||||
#define INST(need_load_from_const_mem) \ | |||||
template void megdnn::cuda::cutlass_wrapper:: \ | |||||
do_conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64< \ | |||||
need_load_from_const_mem>( \ | |||||
const uint8_t* d_src, const int8_t* d_filter, \ | |||||
const int32_t* d_bias, const uint8_t* d_z, uint8_t* d_dst, \ | |||||
int* workspace, const convolution::ConvParam& param, \ | |||||
uint32_t nonlinear_mode, float alpha, float beta, \ | |||||
float gamma, float delta, float theta, float scale, \ | |||||
uint8_t src_zero_point, \ | |||||
const GemmCoord& threadblock_shape, \ | |||||
const GemmCoord& warp_shape, cudaStream_t stream); | |||||
INST(true); | |||||
#undef INST | |||||
/* ===== cutlass kernel wrapper for nchw4 layout and nhwc output ===== */ | /* ===== cutlass kernel wrapper for nchw4 layout and nhwc output ===== */ | ||||
#if MEGDNN_TEGRA_X1 | #if MEGDNN_TEGRA_X1 | ||||
template <bool signedness> | template <bool signedness> | ||||
@@ -970,10 +734,8 @@ void megdnn::cuda::cutlass_wrapper:: | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 64, 32, 32, 2, 16); \ | DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 64, 32, 32, 2, 16); \ | ||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 32, 64, 32, 32, 2, 16); \ | DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 32, 64, 32, 32, 2, 16); \ | ||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 2, 16); \ | DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 2, 16); \ | ||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 64, 32, 64, 32, 32, 2, 16); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 32, 32, 64, 32, 2, 16); \ | DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 32, 32, 64, 32, 2, 16); \ | ||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 32, 32, 64, 32, 32, 2, 16); \ | DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 32, 32, 64, 32, 32, 2, 16); \ | ||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 32, 32, 32, 32, 32, 2, 16); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 128, 16, 16, 128, 16, 1, 8); \ | DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 128, 16, 16, 128, 16, 1, 8); \ | ||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 64, 8, 16, 64, 8, 2, 4); \ | DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 64, 8, 16, 64, 8, 2, 4); \ | ||||
megdnn_assert(false, \ | megdnn_assert(false, \ | ||||
@@ -1039,262 +801,4 @@ INST(true); | |||||
INST(false); | INST(false); | ||||
#undef INST | #undef INST | ||||
/* ====== cutlass kernel wrapper for int4 x int4 nchw64 layout ====== */ | |||||
#if MEGDNN_TEGRA_X1 | |||||
template <bool NeedLoadFromConstMem> | |||||
void megdnn::cuda::cutlass_wrapper:: | |||||
do_conv_bias_int4_int4_implicit_gemm_imma_nhwc( | |||||
const int8_t* /* d_src */, const int8_t* /* d_filter */, | |||||
const int32_t* /* d_bias */, const int8_t* /* d_z */, | |||||
int8_t* /* d_dst */, int* /* workspace */, | |||||
const convolution::ConvParam& /* param */, | |||||
uint32_t /* nonlinear_mode */, float /* alpha */, | |||||
float /* beta */, float /* gamma */, float /* scale */, | |||||
const GemmCoord& /* threadblock_shape */, | |||||
const GemmCoord& /* warp_shape */, | |||||
const int32_t /* access_size */, cudaStream_t /* stream */) {} | |||||
#else | |||||
template <bool NeedLoadFromConstMem> | |||||
void megdnn::cuda::cutlass_wrapper:: | |||||
do_conv_bias_int4_int4_implicit_gemm_imma_nhwc( | |||||
const int8_t* d_src, const int8_t* d_filter, | |||||
const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, | |||||
int* workspace, const convolution::ConvParam& param, | |||||
uint32_t nonlinear_mode, float alpha, float beta, float gamma, | |||||
float scale, const GemmCoord& threadblock_shape, | |||||
const GemmCoord& warp_shape, const int32_t access_size, | |||||
cudaStream_t stream) { | |||||
#define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \ | |||||
threadblock_k_, warp_m_, warp_n_, \ | |||||
warp_k_, access_size_) \ | |||||
if (threadblock_shape.m() == threadblock_m_ && \ | |||||
threadblock_shape.n() == threadblock_n_ && \ | |||||
threadblock_shape.k() == threadblock_k_ && \ | |||||
warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ | |||||
warp_shape.k() == warp_k_ && access_size == access_size_) { \ | |||||
using ThreadBlockShape = \ | |||||
cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \ | |||||
threadblock_k_>; \ | |||||
using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \ | |||||
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; \ | |||||
using Convolution = cutlass::conv::device::Convolution< \ | |||||
cutlass::int4b_t, cutlass::layout::TensorNHWC, \ | |||||
cutlass::int4b_t, cutlass::layout::TensorNCxHWx<access_size_>, \ | |||||
ElementOutput, cutlass::layout::TensorNHWC, int32_t, \ | |||||
cutlass::layout::TensorNHWC, int32_t, \ | |||||
cutlass::conv::ConvType::kConvolution, \ | |||||
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \ | |||||
ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ | |||||
cutlass::conv::threadblock:: \ | |||||
ConvolutionFpropNHWCThreadblockSwizzle, \ | |||||
2, access_size_, access_size_, NeedLoadFromConstMem, \ | |||||
cutlass::arch::OpMultiplyAddSaturate, \ | |||||
cutlass::conv::ImplicitGemmMode::GEMM_TN>; \ | |||||
typename Convolution::ConvolutionParameter conv_param( \ | |||||
param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ | |||||
param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ | |||||
param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \ | |||||
return cutlass_convolution_wrapper<Convolution>( \ | |||||
reinterpret_cast<const cutlass::int4b_t*>(d_src), \ | |||||
reinterpret_cast<const cutlass::int4b_t*>(d_filter), d_bias, \ | |||||
reinterpret_cast<const cutlass::int4b_t*>(d_z), \ | |||||
reinterpret_cast<cutlass::int4b_t*>(d_dst), workspace, \ | |||||
conv_param, epilogue, stream); \ | |||||
} | |||||
#define DISPATCH_KERNEL \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 32); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 16); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 8); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 32); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 16); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 8); \ | |||||
megdnn_assert(false, \ | |||||
"unsupported threadblock shape (%dx%dx%d) and warp shape " \ | |||||
"(%dx%dx%d) and access_size (%d)", \ | |||||
threadblock_shape.m(), threadblock_shape.n(), \ | |||||
threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ | |||||
warp_shape.k(), access_size); | |||||
using ElementOutput = cutlass::int4b_t; | |||||
using ElementAccumulator = int32_t; | |||||
using ElementBias = int32_t; | |||||
using ElementCompute = float; | |||||
using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode; | |||||
switch (nonlinear_mode) { | |||||
case NonlineMode::IDENTITY: { | |||||
using EpilogueOp = | |||||
cutlass::epilogue::thread::BiasAddLinearCombinationClamp< | |||||
ElementOutput, 8, ElementAccumulator, ElementBias, | |||||
ElementCompute>; | |||||
typename EpilogueOp::Params epilogue{alpha, beta, gamma}; | |||||
DISPATCH_KERNEL; | |||||
} | |||||
case NonlineMode::RELU: { | |||||
using EpilogueOp = cutlass::epilogue::thread:: | |||||
BiasAddLinearCombinationReluClamp< | |||||
ElementOutput, 8, ElementAccumulator, ElementBias, | |||||
ElementCompute>; | |||||
typename EpilogueOp::Params epilogue{alpha, beta, gamma, 0}; | |||||
DISPATCH_KERNEL; | |||||
} | |||||
case NonlineMode::H_SWISH: { | |||||
using EpilogueOp = cutlass::epilogue::thread:: | |||||
BiasAddLinearCombinationHSwishClamp< | |||||
ElementOutput, 8, ElementAccumulator, ElementBias, | |||||
ElementCompute>; | |||||
typename EpilogueOp::Params epilogue{alpha, beta, gamma, scale}; | |||||
DISPATCH_KERNEL; | |||||
} | |||||
default: | |||||
megdnn_assert(false, | |||||
"unsupported nonlinear mode for conv bias operator"); | |||||
} | |||||
#undef DISPATCH_KERNEL_WITH_TILE_SHAPE | |||||
#undef DISPATCH_KERNEL | |||||
} | |||||
#endif | |||||
#define INST(need_load_from_const_mem) \ | |||||
template void megdnn::cuda::cutlass_wrapper:: \ | |||||
do_conv_bias_int4_int4_implicit_gemm_imma_nhwc< \ | |||||
need_load_from_const_mem>( \ | |||||
const int8_t* d_src, const int8_t* d_filter, \ | |||||
const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, \ | |||||
int* workspace, const convolution::ConvParam& param, \ | |||||
uint32_t nonlinear_mode, float alpha, float beta, \ | |||||
float gamma, float scale, \ | |||||
const GemmCoord& threadblock_shape, \ | |||||
const GemmCoord& warp_shape, const int32_t access_size, \ | |||||
cudaStream_t stream); | |||||
INST(true); | |||||
INST(false); | |||||
#undef INST | |||||
#if MEGDNN_TEGRA_X1 | |||||
template <bool NeedLoadFromConstMem> | |||||
void megdnn::cuda::cutlass_wrapper:: | |||||
do_conv_bias_uint4_int4_implicit_gemm_imma_nhwc( | |||||
const uint8_t* /* d_src */, const int8_t* /* d_filter */, | |||||
const int32_t* /* d_bias */, const uint8_t* /* d_z */, | |||||
uint8_t* /* d_dst */, int* /* workspace */, | |||||
const convolution::ConvParam& /* param */, | |||||
uint32_t /* nonlinear_mode */, float /* alpha */, | |||||
float /* beta */, float /* gamma */, float /* delta */, | |||||
float /* theta */, float /* scale */, | |||||
uint8_t /* src_zero_point */, | |||||
const GemmCoord& /* threadblock_shape */, | |||||
const GemmCoord& /* warp_shape */, | |||||
const int32_t /* access_size */, cudaStream_t /* stream */) {} | |||||
#else | |||||
template <bool NeedLoadFromConstMem> | |||||
void megdnn::cuda::cutlass_wrapper:: | |||||
do_conv_bias_uint4_int4_implicit_gemm_imma_nhwc( | |||||
const uint8_t* d_src, const int8_t* d_filter, | |||||
const int32_t* d_bias, const uint8_t* d_z, uint8_t* d_dst, | |||||
int* workspace, const convolution::ConvParam& param, | |||||
uint32_t nonlinear_mode, float alpha, float beta, float gamma, | |||||
float delta, float theta, float /* scale */, | |||||
uint8_t src_zero_point, const GemmCoord& threadblock_shape, | |||||
const GemmCoord& warp_shape, const int32_t access_size, | |||||
cudaStream_t stream) { | |||||
#define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \ | |||||
threadblock_k_, warp_m_, warp_n_, \ | |||||
warp_k_, access_size_) \ | |||||
if (threadblock_shape.m() == threadblock_m_ && \ | |||||
threadblock_shape.n() == threadblock_n_ && \ | |||||
threadblock_shape.k() == threadblock_k_ && \ | |||||
warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ | |||||
warp_shape.k() == warp_k_ && access_size == access_size_) { \ | |||||
using ThreadBlockShape = \ | |||||
cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \ | |||||
threadblock_k_>; \ | |||||
using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \ | |||||
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; \ | |||||
using Convolution = cutlass::conv::device::Convolution< \ | |||||
cutlass::uint4b_t, cutlass::layout::TensorNHWC, \ | |||||
cutlass::int4b_t, cutlass::layout::TensorNCxHWx<access_size_>, \ | |||||
ElementOutput, cutlass::layout::TensorNHWC, int32_t, \ | |||||
cutlass::layout::TensorNHWC, int32_t, \ | |||||
cutlass::conv::ConvType::kConvolution, \ | |||||
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \ | |||||
ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ | |||||
cutlass::conv::threadblock:: \ | |||||
ConvolutionFpropNHWCThreadblockSwizzle, \ | |||||
2, access_size_, access_size_, NeedLoadFromConstMem, \ | |||||
cutlass::arch::OpMultiplyAddSaturate, \ | |||||
cutlass::conv::ImplicitGemmMode::GEMM_TN>; \ | |||||
typename Convolution::ConvolutionParameter conv_param( \ | |||||
param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ | |||||
param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ | |||||
param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \ | |||||
return cutlass_convolution_wrapper<Convolution>( \ | |||||
reinterpret_cast<const cutlass::uint4b_t*>(d_src), \ | |||||
reinterpret_cast<const cutlass::int4b_t*>(d_filter), d_bias, \ | |||||
reinterpret_cast<const cutlass::uint4b_t*>(d_z), \ | |||||
reinterpret_cast<cutlass::uint4b_t*>(d_dst), workspace, \ | |||||
conv_param, epilogue, stream, {src_zero_point}); \ | |||||
} | |||||
#define DISPATCH_KERNEL \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 32); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 16); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 8); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 32); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 16); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 8); \ | |||||
megdnn_assert(false, \ | |||||
"unsupported threadblock shape (%dx%dx%d) and warp shape " \ | |||||
"(%dx%dx%d) and access_size (%d)", \ | |||||
threadblock_shape.m(), threadblock_shape.n(), \ | |||||
threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ | |||||
warp_shape.k(), access_size); | |||||
using ElementOutput = cutlass::uint4b_t; | |||||
using ElementAccumulator = int32_t; | |||||
using ElementBias = int32_t; | |||||
using ElementCompute = float; | |||||
using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode; | |||||
switch (nonlinear_mode) { | |||||
case NonlineMode::IDENTITY: { | |||||
using EpilogueOp = | |||||
cutlass::epilogue::thread::BiasAddLinearCombinationClamp< | |||||
ElementOutput, 8, ElementAccumulator, ElementBias, | |||||
ElementCompute>; | |||||
typename EpilogueOp::Params epilogue{alpha, beta, gamma, | |||||
delta + theta}; | |||||
DISPATCH_KERNEL; | |||||
} | |||||
case NonlineMode::RELU: { | |||||
using EpilogueOp = cutlass::epilogue::thread:: | |||||
BiasAddLinearCombinationReluClamp< | |||||
ElementOutput, 8, ElementAccumulator, ElementBias, | |||||
ElementCompute>; | |||||
typename EpilogueOp::Params epilogue{alpha, beta, gamma, | |||||
0, delta, theta}; | |||||
DISPATCH_KERNEL; | |||||
} | |||||
default: | |||||
megdnn_assert(false, | |||||
"unsupported nonlinear mode for conv bias operator"); | |||||
} | |||||
#undef DISPATCH_KERNEL_WITH_TILE_SHAPE | |||||
#undef DISPATCH_KERNEL | |||||
} | |||||
#endif | |||||
#define INST(need_load_from_const_mem) \ | |||||
template void megdnn::cuda::cutlass_wrapper:: \ | |||||
do_conv_bias_uint4_int4_implicit_gemm_imma_nhwc< \ | |||||
need_load_from_const_mem>( \ | |||||
const uint8_t* d_src, const int8_t* d_filter, \ | |||||
const int32_t* d_bias, const uint8_t* d_z, uint8_t* d_dst, \ | |||||
int* workspace, const convolution::ConvParam& param, \ | |||||
uint32_t nonlinear_mode, float alpha, float beta, \ | |||||
float gamma, float delta, float theta, float scale, \ | |||||
uint8_t src_zero_point, \ | |||||
const GemmCoord& threadblock_shape, \ | |||||
const GemmCoord& warp_shape, const int32_t access_size, \ | |||||
cudaStream_t stream); | |||||
INST(true); | |||||
INST(false); | |||||
#undef INST | |||||
// vim: syntax=cuda.doxygen | // vim: syntax=cuda.doxygen |
@@ -0,0 +1,194 @@ | |||||
/** | |||||
* \file dnn/src/cuda/conv_bias/cutlass_reorder_filter.cu | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#include "src/cuda/conv_bias/cutlass_reorder_filter.cuh" | |||||
#include "src/cuda/query_blocksize.cuh" | |||||
#include "src/cuda/integer_subbyte_utils.cuh" | |||||
using namespace megdnn; | |||||
using namespace cuda; | |||||
using namespace cutlass_wrapper; | |||||
namespace { | |||||
template <uint32_t size_bits, uint32_t interleaved> | |||||
__device__ __forceinline__ void reorder_ncxhwx_imma_filter_func( | |||||
int8_t* dst, const int8_t* src, uint32_t OC, uint32_t IC, uint32_t FH, | |||||
uint32_t FW, uint32_t lane, bool trans_oc) { | |||||
static constexpr uint32_t elements_per_lane = 128 / size_bits; | |||||
static constexpr uint32_t threads_per_interleaved = | |||||
interleaved / elements_per_lane; | |||||
static constexpr uint32_t instruction_shape_col = 8; | |||||
// 4 threads per Quad | |||||
static constexpr uint32_t elements_per_thread = instruction_shape_col / 4; | |||||
// 4 threads per Quad | |||||
static constexpr uint32_t reordered_elements_per_thread = interleaved / 4; | |||||
uint32_t id = lane / threads_per_interleaved; | |||||
uint32_t residue = lane % threads_per_interleaved; | |||||
uint32_t ICx = IC / interleaved; | |||||
uint32_t row = id / (ICx * FH * FW); | |||||
uint32_t col = id - row * ICx * FH * FW; | |||||
// transpose ncxhwx to cxhwnx | |||||
uint32_t src_offset = id * interleaved + residue * elements_per_lane; | |||||
row = (trans_oc) ? (row / interleaved) * interleaved + | |||||
((row % reordered_elements_per_thread) / | |||||
elements_per_thread) * | |||||
instruction_shape_col + | |||||
((row % interleaved) / | |||||
reordered_elements_per_thread) * | |||||
elements_per_thread + | |||||
(row % elements_per_thread) | |||||
: row; | |||||
uint32_t dst_offset = | |||||
(col * OC + row) * interleaved + residue * elements_per_lane; | |||||
*(reinterpret_cast<int4*>(dst + dst_offset * size_bits / 8)) = | |||||
*(reinterpret_cast<const int4*>(src + src_offset * size_bits / 8)); | |||||
} | |||||
template <uint32_t size_bits, uint32_t interleaved> | |||||
__global__ void reorder_ncxhwx_imma_filter_kernel( | |||||
int8_t* __restrict__ dst_filter, const int8_t* __restrict__ src_filter, | |||||
uint32_t OC, uint32_t IC, uint32_t FH, uint32_t FW, bool trans_oc) { | |||||
static constexpr uint32_t elements_per_lane = 128 / size_bits; | |||||
const uint32_t size = OC * IC * FH * FW / elements_per_lane; | |||||
uint32_t lane = threadIdx.x + blockIdx.x * blockDim.x; | |||||
if (lane < size) { | |||||
reorder_ncxhwx_imma_filter_func<size_bits, interleaved>( | |||||
dst_filter, src_filter, OC, IC, FH, FW, lane, trans_oc); | |||||
} | |||||
} | |||||
template <uint32_t size_bits, uint32_t alignbits, uint32_t interleaved> | |||||
__device__ __forceinline__ void reorder_nhwc_imma_filter_func( | |||||
int8_t* dst, const int8_t* src, uint32_t OC, uint32_t IC, uint32_t FH, | |||||
uint32_t FW, uint32_t lane, bool trans_oc) { | |||||
static constexpr uint32_t elements_per_access = alignbits / size_bits; | |||||
static constexpr uint32_t instruction_shape_col = 8; | |||||
// 4 threads per Quad | |||||
static constexpr uint32_t elements_per_thread = instruction_shape_col / 4; | |||||
// 4 threads per Quad | |||||
static constexpr uint32_t reordered_elements_per_thread = interleaved / 4; | |||||
uint32_t ICx = IC / elements_per_access; | |||||
uint32_t k = lane / (ICx * FH * FW); | |||||
uint32_t cxrs = lane - k * ICx * FH * FW; | |||||
uint32_t rs = cxrs / ICx; | |||||
uint32_t cx = cxrs - rs * ICx; | |||||
// transpose nhwc to ncxhwx | |||||
uint32_t src_offset = lane * elements_per_access; | |||||
// reorder k | |||||
k = (trans_oc) | |||||
? (k / interleaved) * interleaved + | |||||
((k % reordered_elements_per_thread) / | |||||
elements_per_thread) * | |||||
instruction_shape_col + | |||||
((k % interleaved) / reordered_elements_per_thread) * | |||||
elements_per_thread + | |||||
(k % elements_per_thread) | |||||
: k; | |||||
uint32_t dst_offset = | |||||
(k * ICx * FH * FW + cx * FH * FW + rs) * elements_per_access; | |||||
if (alignbits == 32) { | |||||
*(reinterpret_cast<int*>(dst + dst_offset * size_bits / 8)) = *( | |||||
reinterpret_cast<const int*>(src + src_offset * size_bits / 8)); | |||||
} else if (alignbits == 64) { | |||||
*(reinterpret_cast<int2*>(dst + dst_offset * size_bits / 8)) = | |||||
*(reinterpret_cast<const int2*>(src + | |||||
src_offset * size_bits / 8)); | |||||
} else { | |||||
*(reinterpret_cast<int4*>(dst + dst_offset * size_bits / 8)) = | |||||
*(reinterpret_cast<const int4*>(src + | |||||
src_offset * size_bits / 8)); | |||||
} | |||||
} | |||||
template <uint32_t size_bits, uint32_t alignbits, uint32_t interleaved> | |||||
__global__ void reorder_nhwc_imma_filter_kernel( | |||||
int8_t* __restrict__ dst_filter, const int8_t* __restrict__ src_filter, | |||||
uint32_t OC, uint32_t IC, uint32_t FH, uint32_t FW, bool trans_oc) { | |||||
static constexpr uint32_t elements_per_access = alignbits / size_bits; | |||||
const uint32_t size = OC * IC * FH * FW / elements_per_access; | |||||
uint32_t lane = threadIdx.x + blockIdx.x * blockDim.x; | |||||
if (lane < size) { | |||||
reorder_nhwc_imma_filter_func<size_bits, alignbits, interleaved>( | |||||
dst_filter, src_filter, OC, IC, FH, FW, lane, trans_oc); | |||||
} | |||||
} | |||||
} // namespace | |||||
template <uint32_t size_bits, uint32_t interleaved> | |||||
void megdnn::cuda::cutlass_wrapper::reorder_ncxhwx_imma_filter( | |||||
int8_t* dst_filter, const int8_t* src_filter, uint32_t OC, uint32_t IC, | |||||
uint32_t FH, uint32_t FW, bool trans_oc, cudaStream_t stream) { | |||||
static constexpr uint32_t elements_per_lane = 128 / size_bits; | |||||
uint32_t nr_threads = | |||||
query_blocksize_for_kernel(reinterpret_cast<const void*>( | |||||
reorder_ncxhwx_imma_filter_kernel<size_bits, interleaved>)); | |||||
uint32_t vthreads = DIVUP(OC * IC * FH * FW, elements_per_lane); | |||||
nr_threads = std::min(nr_threads, vthreads); | |||||
uint32_t nr_blocks = DIVUP(vthreads, nr_threads); | |||||
reorder_ncxhwx_imma_filter_kernel<size_bits, interleaved> | |||||
<<<nr_blocks, nr_threads, 0, stream>>>(dst_filter, src_filter, OC, | |||||
IC, FH, FW, trans_oc); | |||||
after_kernel_launch(); | |||||
} | |||||
template <uint32_t size_bits, uint32_t alignbits> | |||||
void megdnn::cuda::cutlass_wrapper::reorder_nhwc_imma_filter( | |||||
int8_t* dst_filter, const int8_t* src_filter, uint32_t OC, uint32_t IC, | |||||
uint32_t FH, uint32_t FW, bool trans_oc, uint32_t oc_interleaved, | |||||
cudaStream_t stream) { | |||||
static constexpr uint32_t elements_per_access = alignbits / size_bits; | |||||
uint32_t nr_threads = | |||||
query_blocksize_for_kernel(reinterpret_cast<const void*>( | |||||
reorder_nhwc_imma_filter_kernel<size_bits, alignbits, 32>)); | |||||
uint32_t vthreads = DIVUP(OC * IC * FH * FW, elements_per_access); | |||||
nr_threads = std::min(nr_threads, vthreads); | |||||
uint32_t nr_blocks = DIVUP(vthreads, nr_threads); | |||||
if (oc_interleaved == 32) { | |||||
reorder_nhwc_imma_filter_kernel<size_bits, alignbits, 32> | |||||
<<<nr_blocks, nr_threads, 0, stream>>>( | |||||
dst_filter, src_filter, OC, IC, FH, FW, trans_oc); | |||||
} else { | |||||
reorder_nhwc_imma_filter_kernel<size_bits, alignbits, 64> | |||||
<<<nr_blocks, nr_threads, 0, stream>>>( | |||||
dst_filter, src_filter, OC, IC, FH, FW, trans_oc); | |||||
} | |||||
after_kernel_launch(); | |||||
} | |||||
#define INST(_size_bits, _interleaved) \ | |||||
template void megdnn::cuda::cutlass_wrapper::reorder_ncxhwx_imma_filter< \ | |||||
_size_bits, _interleaved>(int8_t * dst_filter, \ | |||||
const int8_t* src_filter, uint32_t OC, \ | |||||
uint32_t IC, uint32_t FH, uint32_t FW, \ | |||||
bool trans_oc, cudaStream_t stream); | |||||
INST(8, 32) | |||||
INST(4, 64) | |||||
#undef INST | |||||
#define INST(_size_bits, _alignbits) \ | |||||
template void megdnn::cuda::cutlass_wrapper::reorder_nhwc_imma_filter< \ | |||||
_size_bits, _alignbits>( \ | |||||
int8_t * dst_filter, const int8_t* src_filter, uint32_t OC, \ | |||||
uint32_t IC, uint32_t FH, uint32_t FW, bool trans_oc, \ | |||||
uint32_t oc_interleaved, cudaStream_t stream); | |||||
INST(4, 32) | |||||
INST(4, 64) | |||||
INST(4, 128) | |||||
#undef INST | |||||
// vim: syntax=cuda.doxygen |
@@ -0,0 +1,33 @@ | |||||
/** | |||||
* \file dnn/src/cuda/conv_bias/cutlass_reorder_filter.cuh | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#pragma once | |||||
#include "src/cuda/utils.cuh" | |||||
namespace megdnn { | |||||
namespace cuda { | |||||
namespace cutlass_wrapper { | |||||
template <uint32_t size_bits, uint32_t interleaved> | |||||
void reorder_ncxhwx_imma_filter(int8_t* dst_filter, const int8_t* src_filter, | |||||
uint32_t OC, uint32_t IC, uint32_t FH, | |||||
uint32_t FW, bool trans_oc, | |||||
cudaStream_t stream); | |||||
template <uint32_t size_bits, uint32_t alignbits> | |||||
void reorder_nhwc_imma_filter(int8_t* dst_filter, const int8_t* src_filter, | |||||
uint32_t OC, uint32_t IC, uint32_t FH, | |||||
uint32_t FW, bool trans_oc, | |||||
uint32_t oc_interleaved, cudaStream_t stream); | |||||
} // namespace cutlass_wrapper | |||||
} // namespace cuda | |||||
} // namespace megdnn |
@@ -102,7 +102,7 @@ void ConvBiasForwardImpl::AlgoInt4Int4NCHW64IMMAImplicitGemm::do_exec( | |||||
reinterpret_cast<int8_t*>(z_ptr), | reinterpret_cast<int8_t*>(z_ptr), | ||||
reinterpret_cast<int8_t*>(args.dst_tensor->raw_ptr), nullptr, | reinterpret_cast<int8_t*>(args.dst_tensor->raw_ptr), nullptr, | ||||
kern_param, nonlinear_mode, alpha, beta, gamma, dst_scale, | kern_param, nonlinear_mode, alpha, beta, gamma, dst_scale, | ||||
threadblock_shape, warp_shape, stream); | |||||
threadblock_shape, warp_shape, m_algo_param.stage, stream); | |||||
} | } | ||||
#endif | #endif | ||||
@@ -104,7 +104,7 @@ void ConvBiasForwardImpl::AlgoInt4Int4NHWCIMMAImplicitGemm::do_exec( | |||||
reinterpret_cast<int8_t*>(args.dst_tensor->raw_ptr), nullptr, | reinterpret_cast<int8_t*>(args.dst_tensor->raw_ptr), nullptr, | ||||
kern_param, nonlinear_mode, alpha, beta, gamma, dst_scale, | kern_param, nonlinear_mode, alpha, beta, gamma, dst_scale, | ||||
threadblock_shape, warp_shape, m_algo_param.access_size, | threadblock_shape, warp_shape, m_algo_param.access_size, | ||||
stream); | |||||
m_algo_param.stage, stream); | |||||
} else { | } else { | ||||
cutlass_wrapper::do_conv_bias_int4_int4_implicit_gemm_imma_nhwc<true>( | cutlass_wrapper::do_conv_bias_int4_int4_implicit_gemm_imma_nhwc<true>( | ||||
reinterpret_cast<int8_t*>(args.src_tensor->raw_ptr), | reinterpret_cast<int8_t*>(args.src_tensor->raw_ptr), | ||||
@@ -114,7 +114,7 @@ void ConvBiasForwardImpl::AlgoInt4Int4NHWCIMMAImplicitGemm::do_exec( | |||||
reinterpret_cast<int8_t*>(args.dst_tensor->raw_ptr), nullptr, | reinterpret_cast<int8_t*>(args.dst_tensor->raw_ptr), nullptr, | ||||
kern_param, nonlinear_mode, alpha, beta, gamma, dst_scale, | kern_param, nonlinear_mode, alpha, beta, gamma, dst_scale, | ||||
threadblock_shape, warp_shape, m_algo_param.access_size, | threadblock_shape, warp_shape, m_algo_param.access_size, | ||||
stream); | |||||
m_algo_param.stage, stream); | |||||
} | } | ||||
} | } | ||||
#endif | #endif | ||||
@@ -12,6 +12,7 @@ | |||||
#include "./algo.h" | #include "./algo.h" | ||||
#include "src/common/conv_bias.h" | #include "src/common/conv_bias.h" | ||||
#include "src/cuda/conv_bias/cutlass_reorder_filter.cuh" | |||||
#include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh" | #include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh" | ||||
#include "src/cuda/conv_bias/reduce_filter.cuh" | #include "src/cuda/conv_bias/reduce_filter.cuh" | ||||
#include "src/cuda/convolution_helper/parameter.cuh" | #include "src/cuda/convolution_helper/parameter.cuh" | ||||
@@ -121,41 +122,26 @@ void ConvBiasForwardImpl::AlgoInt4NCHW64IMMAImplicitGemmBase::exec( | |||||
std::string ConvBiasForwardImpl::AlgoInt4NCHW64IMMAImplicitGemmBase::to_string( | std::string ConvBiasForwardImpl::AlgoInt4NCHW64IMMAImplicitGemmBase::to_string( | ||||
AlgoParam algo_param) { | AlgoParam algo_param) { | ||||
return ssprintf("%dX%dX%d_%dX%dX%d", algo_param.threadblock_m, | |||||
return ssprintf("%dX%dX%d_%dX%dX%d_%d", algo_param.threadblock_m, | |||||
algo_param.threadblock_n, algo_param.threadblock_k, | algo_param.threadblock_n, algo_param.threadblock_k, | ||||
algo_param.warp_m, algo_param.warp_n, algo_param.warp_k); | |||||
algo_param.warp_m, algo_param.warp_n, algo_param.warp_k, | |||||
algo_param.stage); | |||||
} | } | ||||
void ConvBiasForwardImpl::AlgoInt4NCHW64IMMAImplicitGemmBase::reorder_filter( | void ConvBiasForwardImpl::AlgoInt4NCHW64IMMAImplicitGemmBase::reorder_filter( | ||||
const ExecArgs& args, void* reordered_filter) const { | const ExecArgs& args, void* reordered_filter) const { | ||||
auto&& param = args.opr->param(); | |||||
size_t ci = args.src_layout->operator[](1) * 64; | |||||
size_t co = args.dst_layout->operator[](1) * 64; | |||||
auto&& fm = args.filter_meta; | auto&& fm = args.filter_meta; | ||||
size_t n = args.src_layout->operator[](0), | |||||
ci = args.src_layout->operator[](1) * 64, | |||||
hi = args.src_layout->operator[](2), | |||||
wi = args.src_layout->operator[](3); | |||||
size_t co = args.dst_layout->operator[](1) * 64, | |||||
ho = args.dst_layout->operator[](2), | |||||
wo = args.dst_layout->operator[](3); | |||||
UNPACK_CONV_PARAMETER(fm, param); | |||||
MARK_USED_VAR; | |||||
// filter: KCRS64 => CRSK64 | |||||
TensorLayout src{{co, ci / 64, fh, fw, 64}, dtype::QuantizedS4()}; | |||||
src.init_contiguous_stride(); | |||||
TensorLayout dst = src; | |||||
dst.stride[0] = 64; | |||||
dst.stride[1] = co * fh * fw * 64; | |||||
dst.stride[2] = co * fw * 64; | |||||
dst.stride[3] = co * 64; | |||||
dst.stride[4] = 1; | |||||
TensorND ts_src, ts_dst; | |||||
ts_src.raw_ptr = args.filter_tensor->raw_ptr; | |||||
ts_src.layout = src; | |||||
ts_dst.raw_ptr = reordered_filter; | |||||
ts_dst.layout = dst; | |||||
auto&& transpose = args.opr->handle()->create_operator<RelayoutForward>(); | |||||
transpose->exec(ts_src, ts_dst); | |||||
size_t fh = fm.spatial[0], fw = fm.spatial[1]; | |||||
cudaStream_t stream = cuda_stream(args.opr->handle()); | |||||
// filter: KCRS64 => CRSK64 and reorder oc | |||||
cutlass_wrapper::reorder_ncxhwx_imma_filter<4, 64>( | |||||
reinterpret_cast<int8_t*>(reordered_filter), | |||||
reinterpret_cast<int8_t*>(args.filter_tensor->raw_ptr), co, ci, fh, | |||||
fw, true, stream); | |||||
} | } | ||||
#endif | #endif | ||||
@@ -12,6 +12,7 @@ | |||||
#include "./algo.h" | #include "./algo.h" | ||||
#include "src/common/conv_bias.h" | #include "src/common/conv_bias.h" | ||||
#include "src/cuda/conv_bias/cutlass_reorder_filter.cuh" | |||||
#include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh" | #include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh" | ||||
#include "src/cuda/conv_bias/reduce_filter.cuh" | #include "src/cuda/conv_bias/reduce_filter.cuh" | ||||
#include "src/cuda/convolution_helper/parameter.cuh" | #include "src/cuda/convolution_helper/parameter.cuh" | ||||
@@ -128,10 +129,10 @@ void ConvBiasForwardImpl::AlgoInt4NHWCIMMAImplicitGemmBase::exec( | |||||
std::string ConvBiasForwardImpl::AlgoInt4NHWCIMMAImplicitGemmBase::to_string( | std::string ConvBiasForwardImpl::AlgoInt4NHWCIMMAImplicitGemmBase::to_string( | ||||
AlgoParam algo_param) { | AlgoParam algo_param) { | ||||
return ssprintf("%dX%dX%d_%dX%dX%d_%d", algo_param.threadblock_m, | |||||
return ssprintf("%dX%dX%d_%dX%dX%d_%d_%d", algo_param.threadblock_m, | |||||
algo_param.threadblock_n, algo_param.threadblock_k, | algo_param.threadblock_n, algo_param.threadblock_k, | ||||
algo_param.warp_m, algo_param.warp_n, algo_param.warp_k, | algo_param.warp_m, algo_param.warp_n, algo_param.warp_k, | ||||
algo_param.access_size); | |||||
algo_param.stage, algo_param.access_size); | |||||
} | } | ||||
void ConvBiasForwardImpl::AlgoInt4NHWCIMMAImplicitGemmBase::reorder_filter( | void ConvBiasForwardImpl::AlgoInt4NHWCIMMAImplicitGemmBase::reorder_filter( | ||||
@@ -142,17 +143,32 @@ void ConvBiasForwardImpl::AlgoInt4NHWCIMMAImplicitGemmBase::reorder_filter( | |||||
fh = args.filter_layout->operator[](1), | fh = args.filter_layout->operator[](1), | ||||
fw = args.filter_layout->operator[](2); | fw = args.filter_layout->operator[](2); | ||||
// reformat grad from nhwc to ncxhwx | |||||
TensorLayout exec_src{{co, fh, fw, ci / iterleaved, (size_t)iterleaved / 2}, | |||||
dtype::Int8()}; | |||||
TensorLayout exec_dst{{co, ci / iterleaved, fh, fw, (size_t)iterleaved / 2}, | |||||
dtype::Int8()}; | |||||
exec_src = exec_src.dimshuffle({0, 3, 1, 2, 4}); | |||||
cudaStream_t stream = cuda_stream(args.opr->handle()); | |||||
auto&& relayout = args.opr->handle()->create_operator<RelayoutForward>(); | |||||
relayout->exec({args.filter_tensor->raw_ptr, exec_src}, | |||||
{reordered_filter, exec_dst}); | |||||
// reformat filter from nhwc to ncxhwx and reorder oc | |||||
// use trans_oc threadblock_n must be 32 or 64 | |||||
bool trans_oc = ((co % m_algo_param.threadblock_n == 0) && | |||||
(m_algo_param.threadblock_n == 32 || | |||||
m_algo_param.threadblock_n == 64)); | |||||
uint32_t oc_iterleave = (m_algo_param.threadblock_n == 64) ? 64 : 32; | |||||
if (iterleaved == 8) { | |||||
cutlass_wrapper::reorder_nhwc_imma_filter<4, 32>( | |||||
reinterpret_cast<int8_t*>(reordered_filter), | |||||
reinterpret_cast<int8_t*>(args.filter_tensor->raw_ptr), co, ci, | |||||
fh, fw, trans_oc, oc_iterleave, stream); | |||||
} else if (iterleaved == 16) { | |||||
cutlass_wrapper::reorder_nhwc_imma_filter<4, 64>( | |||||
reinterpret_cast<int8_t*>(reordered_filter), | |||||
reinterpret_cast<int8_t*>(args.filter_tensor->raw_ptr), co, ci, | |||||
fh, fw, trans_oc, oc_iterleave, stream); | |||||
} else { | |||||
megdnn_assert(iterleaved == 32); | |||||
cutlass_wrapper::reorder_nhwc_imma_filter<4, 128>( | |||||
reinterpret_cast<int8_t*>(reordered_filter), | |||||
reinterpret_cast<int8_t*>(args.filter_tensor->raw_ptr), co, ci, | |||||
fh, fw, trans_oc, oc_iterleave, stream); | |||||
} | |||||
} | } | ||||
#endif | #endif | ||||
@@ -11,6 +11,7 @@ | |||||
*/ | */ | ||||
#include "./algo.h" | #include "./algo.h" | ||||
#include "src/cuda/conv_bias/cutlass_reorder_filter.cuh" | |||||
#include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh" | #include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh" | ||||
#include "src/cuda/convolution_helper/parameter.cuh" | #include "src/cuda/convolution_helper/parameter.cuh" | ||||
#include "src/cuda/utils.h" | #include "src/cuda/utils.h" | ||||
@@ -110,11 +111,14 @@ void ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::exec( | |||||
size_t ho = args.dst_layout->operator[](2), | size_t ho = args.dst_layout->operator[](2), | ||||
wo = args.dst_layout->operator[](3); | wo = args.dst_layout->operator[](3); | ||||
size_t co; | size_t co; | ||||
bool trans_oc; | |||||
if (param.format == Format::NCHW32) { | if (param.format == Format::NCHW32) { | ||||
co = args.dst_layout->operator[](1) * 32; | co = args.dst_layout->operator[](1) * 32; | ||||
trans_oc = true; | |||||
} else { | } else { | ||||
megdnn_assert(param.format == Format::NCHW32_NCHW4); | megdnn_assert(param.format == Format::NCHW32_NCHW4); | ||||
co = args.dst_layout->operator[](1) * 4; | co = args.dst_layout->operator[](1) * 4; | ||||
trans_oc = false; | |||||
} | } | ||||
UNPACK_CONV_PARAMETER(fm, param); | UNPACK_CONV_PARAMETER(fm, param); | ||||
MARK_USED_VAR | MARK_USED_VAR | ||||
@@ -123,23 +127,11 @@ void ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::exec( | |||||
int8_t* filter_ptr = nullptr; | int8_t* filter_ptr = nullptr; | ||||
if (args.preprocessed_filter == nullptr) { | if (args.preprocessed_filter == nullptr) { | ||||
filter_ptr = reinterpret_cast<int8_t*>(args.workspace.raw_ptr); | filter_ptr = reinterpret_cast<int8_t*>(args.workspace.raw_ptr); | ||||
// reformat filter from nchw32 to chwn32 | |||||
TensorLayout src{{co, ci / 32, fh, fw, 32}, dtype::Int8()}; | |||||
src.init_contiguous_stride(); | |||||
TensorLayout dst = src; | |||||
dst.stride[0] = 32; | |||||
dst.stride[1] = co * fh * fw * 32; | |||||
dst.stride[2] = co * fw * 32; | |||||
dst.stride[3] = co * 32; | |||||
dst.stride[4] = 1; | |||||
TensorND ts_src, ts_dst; | |||||
ts_src.raw_ptr = args.filter_tensor->raw_ptr; | |||||
ts_src.layout = src; | |||||
ts_dst.raw_ptr = args.workspace.raw_ptr; | |||||
ts_dst.layout = dst; | |||||
auto&& transpose = | |||||
args.opr->handle()->create_operator<RelayoutForward>(); | |||||
transpose->exec(ts_src, ts_dst); | |||||
// filter: KCRS32 => CRSK32 and reorder oc | |||||
cutlass_wrapper::reorder_ncxhwx_imma_filter<8, 32>( | |||||
filter_ptr, | |||||
reinterpret_cast<int8_t*>(args.filter_tensor->raw_ptr), co, ci, | |||||
fh, fw, trans_oc, stream); | |||||
} else { | } else { | ||||
filter_ptr = reinterpret_cast<int8_t*>( | filter_ptr = reinterpret_cast<int8_t*>( | ||||
args.preprocessed_filter->tensors[0].raw_ptr); | args.preprocessed_filter->tensors[0].raw_ptr); | ||||
@@ -182,7 +174,7 @@ void ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::exec( | |||||
cutlass_wrapper::GemmCoord{m_algo_param.warp_m, | cutlass_wrapper::GemmCoord{m_algo_param.warp_m, | ||||
m_algo_param.warp_n, | m_algo_param.warp_n, | ||||
m_algo_param.warp_k}, | m_algo_param.warp_k}, | ||||
stream); | |||||
m_algo_param.stage, stream); | |||||
} else { | } else { | ||||
megdnn_assert(param.format == Format::NCHW32_NCHW4); | megdnn_assert(param.format == Format::NCHW32_NCHW4); | ||||
cutlass_wrapper:: | cutlass_wrapper:: | ||||
@@ -202,7 +194,7 @@ void ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::exec( | |||||
cutlass_wrapper::GemmCoord{m_algo_param.warp_m, | cutlass_wrapper::GemmCoord{m_algo_param.warp_m, | ||||
m_algo_param.warp_n, | m_algo_param.warp_n, | ||||
m_algo_param.warp_k}, | m_algo_param.warp_k}, | ||||
stream); | |||||
m_algo_param.stage, stream); | |||||
} | } | ||||
} else { | } else { | ||||
if (param.format == Format::NCHW32) { | if (param.format == Format::NCHW32) { | ||||
@@ -218,7 +210,7 @@ void ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::exec( | |||||
cutlass_wrapper::GemmCoord{m_algo_param.warp_m, | cutlass_wrapper::GemmCoord{m_algo_param.warp_m, | ||||
m_algo_param.warp_n, | m_algo_param.warp_n, | ||||
m_algo_param.warp_k}, | m_algo_param.warp_k}, | ||||
stream); | |||||
m_algo_param.stage, stream); | |||||
} else { | } else { | ||||
megdnn_assert(param.format == Format::NCHW32_NCHW4); | megdnn_assert(param.format == Format::NCHW32_NCHW4); | ||||
cutlass_wrapper:: | cutlass_wrapper:: | ||||
@@ -238,7 +230,7 @@ void ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::exec( | |||||
cutlass_wrapper::GemmCoord{m_algo_param.warp_m, | cutlass_wrapper::GemmCoord{m_algo_param.warp_m, | ||||
m_algo_param.warp_n, | m_algo_param.warp_n, | ||||
m_algo_param.warp_k}, | m_algo_param.warp_k}, | ||||
stream); | |||||
m_algo_param.stage, stream); | |||||
} | } | ||||
} | } | ||||
after_kernel_launch(); | after_kernel_launch(); | ||||
@@ -246,9 +238,10 @@ void ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::exec( | |||||
std::string ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::to_string( | std::string ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::to_string( | ||||
AlgoParam algo_param) { | AlgoParam algo_param) { | ||||
return ssprintf("%uX%uX%u_%uX%uX%u", algo_param.threadblock_m, | |||||
return ssprintf("%uX%uX%u_%uX%uX%u_%u", algo_param.threadblock_m, | |||||
algo_param.threadblock_n, algo_param.threadblock_k, | algo_param.threadblock_n, algo_param.threadblock_k, | ||||
algo_param.warp_m, algo_param.warp_n, algo_param.warp_k); | |||||
algo_param.warp_m, algo_param.warp_n, algo_param.warp_k, | |||||
algo_param.stage); | |||||
} | } | ||||
size_t ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm:: | size_t ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm:: | ||||
@@ -267,36 +260,26 @@ void ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::exec_preprocess( | |||||
using Format = Param::Format; | using Format = Param::Format; | ||||
auto&& param = args.opr->param(); | auto&& param = args.opr->param(); | ||||
auto&& fm = args.filter_meta; | auto&& fm = args.filter_meta; | ||||
size_t n = args.src_layout->operator[](0), | |||||
ci = args.src_layout->operator[](1) * 32, | |||||
hi = args.src_layout->operator[](2), | |||||
wi = args.src_layout->operator[](3); | |||||
size_t ho = args.dst_layout->operator[](2), | |||||
wo = args.dst_layout->operator[](3); | |||||
size_t ci = args.src_layout->operator[](1) * 32; | |||||
size_t co; | size_t co; | ||||
bool trans_oc; | |||||
if (param.format == Format::NCHW32) { | if (param.format == Format::NCHW32) { | ||||
co = args.dst_layout->operator[](1) * 32; | co = args.dst_layout->operator[](1) * 32; | ||||
trans_oc = true; | |||||
} else { | } else { | ||||
megdnn_assert(param.format == Format::NCHW32_NCHW4); | megdnn_assert(param.format == Format::NCHW32_NCHW4); | ||||
co = args.dst_layout->operator[](1) * 4; | co = args.dst_layout->operator[](1) * 4; | ||||
trans_oc = false; | |||||
} | } | ||||
UNPACK_CONV_PARAMETER(fm, param); | |||||
MARK_USED_VAR | |||||
TensorLayout src{{co, ci / 32, fh, fw, 32}, dtype::Int8()}; | |||||
src.init_contiguous_stride(); | |||||
TensorLayout dst = src; | |||||
dst.stride[0] = 32; | |||||
dst.stride[1] = co * fh * fw * 32; | |||||
dst.stride[2] = co * fw * 32; | |||||
dst.stride[3] = co * 32; | |||||
dst.stride[4] = 1; | |||||
TensorND ts_src, ts_dst; | |||||
ts_src.raw_ptr = args.filter_tensor->raw_ptr; | |||||
ts_src.layout = src; | |||||
ts_dst.raw_ptr = args.preprocessed_filter->tensors[0].raw_ptr; | |||||
ts_dst.layout = dst; | |||||
auto&& transpose = args.opr->handle()->create_operator<RelayoutForward>(); | |||||
transpose->exec(ts_src, ts_dst); | |||||
size_t fh = fm.spatial[0], fw = fm.spatial[1]; | |||||
cudaStream_t stream = cuda_stream(args.opr->handle()); | |||||
// filter: KCRS32 => CRSK32 and reorder oc | |||||
cutlass_wrapper::reorder_ncxhwx_imma_filter<8, 32>( | |||||
reinterpret_cast<int8_t*>( | |||||
args.preprocessed_filter->tensors[0].raw_ptr), | |||||
reinterpret_cast<int8_t*>(args.filter_tensor->raw_ptr), co, ci, fh, | |||||
fw, trans_oc, stream); | |||||
} | } | ||||
#endif | #endif | ||||
@@ -144,7 +144,8 @@ void ConvBiasForwardImpl::AlgoUInt4Int4NCHW64IMMAImplicitGemm::do_exec( | |||||
reinterpret_cast<uint8_t*>(z_ptr), | reinterpret_cast<uint8_t*>(z_ptr), | ||||
reinterpret_cast<uint8_t*>(args.dst_tensor->raw_ptr), nullptr, | reinterpret_cast<uint8_t*>(args.dst_tensor->raw_ptr), nullptr, | ||||
kern_param, nonlinear_mode, alpha, beta, gamma, delta, theta, | kern_param, nonlinear_mode, alpha, beta, gamma, delta, theta, | ||||
dst_scale, src_zero, threadblock_shape, warp_shape, stream); | |||||
dst_scale, src_zero, threadblock_shape, warp_shape, | |||||
m_algo_param.stage, stream); | |||||
} | } | ||||
void ConvBiasForwardImpl::AlgoUInt4Int4NCHW64IMMAImplicitGemm::update_bias( | void ConvBiasForwardImpl::AlgoUInt4Int4NCHW64IMMAImplicitGemm::update_bias( | ||||
@@ -147,7 +147,7 @@ void ConvBiasForwardImpl::AlgoUInt4Int4NHWCIMMAImplicitGemm::do_exec( | |||||
reinterpret_cast<uint8_t*>(args.dst_tensor->raw_ptr), nullptr, | reinterpret_cast<uint8_t*>(args.dst_tensor->raw_ptr), nullptr, | ||||
kern_param, nonlinear_mode, alpha, beta, gamma, delta, theta, | kern_param, nonlinear_mode, alpha, beta, gamma, delta, theta, | ||||
dst_scale, src_zero, threadblock_shape, warp_shape, | dst_scale, src_zero, threadblock_shape, warp_shape, | ||||
m_algo_param.access_size, stream); | |||||
m_algo_param.access_size, m_algo_param.stage, stream); | |||||
} else { | } else { | ||||
cutlass_wrapper::do_conv_bias_uint4_int4_implicit_gemm_imma_nhwc<true>( | cutlass_wrapper::do_conv_bias_uint4_int4_implicit_gemm_imma_nhwc<true>( | ||||
reinterpret_cast<uint8_t*>(args.src_tensor->raw_ptr), | reinterpret_cast<uint8_t*>(args.src_tensor->raw_ptr), | ||||
@@ -157,7 +157,7 @@ void ConvBiasForwardImpl::AlgoUInt4Int4NHWCIMMAImplicitGemm::do_exec( | |||||
reinterpret_cast<uint8_t*>(args.dst_tensor->raw_ptr), nullptr, | reinterpret_cast<uint8_t*>(args.dst_tensor->raw_ptr), nullptr, | ||||
kern_param, nonlinear_mode, alpha, beta, gamma, delta, theta, | kern_param, nonlinear_mode, alpha, beta, gamma, delta, theta, | ||||
dst_scale, src_zero, threadblock_shape, warp_shape, | dst_scale, src_zero, threadblock_shape, warp_shape, | ||||
m_algo_param.access_size, stream); | |||||
m_algo_param.access_size, m_algo_param.stage, stream); | |||||
} | } | ||||
} | } | ||||
@@ -840,21 +840,21 @@ TEST_F(CUDA, CUTLASS_CONV_BIAS_INT8_NCHW32_IMMA) { | |||||
param.pad_h = param.pad_w = 1; | param.pad_h = param.pad_w = 1; | ||||
param.stride_h = param.stride_w = 1; | param.stride_h = param.stride_w = 1; | ||||
param.format = param::ConvBias::Format::NCHW32; | param.format = param::ConvBias::Format::NCHW32; | ||||
checker.set_param(param).execs({{16, 16, 7, 7, 32}, | |||||
{512, 16, 3, 3, 32}, | |||||
{1, 16, 1, 1, 32}, | |||||
checker.set_param(param).execs({{16, 8, 7, 7, 32}, | |||||
{256, 8, 3, 3, 32}, | |||||
{1, 8, 1, 1, 32}, | |||||
{}, | {}, | ||||
{}}); | {}}); | ||||
param.nonlineMode = param::ConvBias::NonlineMode::RELU; | param.nonlineMode = param::ConvBias::NonlineMode::RELU; | ||||
checker.set_param(param).execs({{16, 16, 7, 7, 32}, | |||||
{512, 16, 1, 1, 32}, | |||||
{1, 16, 1, 1, 32}, | |||||
checker.set_param(param).execs({{16, 8, 7, 7, 32}, | |||||
{256, 8, 1, 1, 32}, | |||||
{1, 8, 1, 1, 32}, | |||||
{}, | {}, | ||||
{}}); | {}}); | ||||
param.nonlineMode = param::ConvBias::NonlineMode::H_SWISH; | param.nonlineMode = param::ConvBias::NonlineMode::H_SWISH; | ||||
checker.set_param(param).execs({{16, 16, 7, 7, 32}, | |||||
{512, 16, 3, 3, 32}, | |||||
{1, 16, 1, 1, 32}, | |||||
checker.set_param(param).execs({{16, 8, 7, 7, 32}, | |||||
{256, 8, 3, 3, 32}, | |||||
{1, 8, 1, 1, 32}, | |||||
{}, | {}, | ||||
{}}); | {}}); | ||||
// use non integer scale | // use non integer scale | ||||
@@ -867,18 +867,18 @@ TEST_F(CUDA, CUTLASS_CONV_BIAS_INT8_NCHW32_IMMA) { | |||||
.set_epsilon(1 + 1e-3) | .set_epsilon(1 + 1e-3) | ||||
.set_max_avg_error(1e-1) | .set_max_avg_error(1e-1) | ||||
.set_max_avg_biased_error(1e-1) | .set_max_avg_biased_error(1e-1) | ||||
.execs({{16, 16, 7, 7, 32}, | |||||
{512, 16, 3, 3, 32}, | |||||
{1, 16, 1, 1, 32}, | |||||
{16, 16, 7, 7, 32}, | |||||
.execs({{16, 8, 7, 7, 32}, | |||||
{256, 8, 3, 3, 32}, | |||||
{1, 8, 1, 1, 32}, | |||||
{16, 8, 7, 7, 32}, | |||||
{}}); | {}}); | ||||
}; | }; | ||||
std::string algo = ConvBias::algo_name<ConvBias::DirectParam>( | std::string algo = ConvBias::algo_name<ConvBias::DirectParam>( | ||||
"INT8_NCHW32_IMMA_IMPLICIT_GEMM_256X128X64_64X64X64", | |||||
"INT8_NCHW32_IMMA_IMPLICIT_GEMM_128X128X64_64X64X64_2", | |||||
ConvBias::DirectParam{}); | ConvBias::DirectParam{}); | ||||
check(algo); | check(algo); | ||||
algo = ConvBias::algo_name<ConvBias::DirectParam>( | algo = ConvBias::algo_name<ConvBias::DirectParam>( | ||||
"INT8_NCHW32_IMMA_IMPLICIT_GEMM_32X64X64_32X16X64", | |||||
"INT8_NCHW32_IMMA_IMPLICIT_GEMM_128X32X32_64X32X32_1", | |||||
ConvBias::DirectParam{}); | ConvBias::DirectParam{}); | ||||
check(algo); | check(algo); | ||||
} | } | ||||
@@ -969,7 +969,7 @@ TEST_F(CUDA, CUTLASS_CONV_BIAS_INT8_NCHW32_NCHW4) { | |||||
checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker< | checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker< | ||||
ConvBiasForward>( | ConvBiasForward>( | ||||
ConvBias::algo_name<ConvBias::DirectParam>( | ConvBias::algo_name<ConvBias::DirectParam>( | ||||
"INT8_NCHW32_IMMA_IMPLICIT_GEMM_256X128X64_64X64X64", | |||||
"INT8_NCHW32_IMMA_IMPLICIT_GEMM_128X128X64_64X64X64_2", | |||||
ConvBias::DirectParam{}) | ConvBias::DirectParam{}) | ||||
.c_str())); | .c_str())); | ||||
checker.set_dtype(0, dtype::QuantizedS8(1.9980618f)) | checker.set_dtype(0, dtype::QuantizedS8(1.9980618f)) | ||||