GitOrigin-RevId: 0b5574f526
tags/v1.6.0-rc1
@@ -20,7 +20,7 @@ class Conv2dOperation: | |||
# | |||
def __init__(self, conv_kind, conv_type, arch, tile_description, src, flt, bias, dst, element_epilogue, \ | |||
epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity4, \ | |||
need_load_from_const = True, implicit_gemm_mode = ImplicitGemmMode.GemmNt): | |||
need_load_from_const = True, implicit_gemm_mode = ImplicitGemmMode.GemmNT, without_shared_load = False): | |||
self.operation_kind = OperationKind.Conv2d | |||
self.conv_kind = conv_kind | |||
@@ -36,6 +36,7 @@ class Conv2dOperation: | |||
self.swizzling_functor = swizzling_functor | |||
self.need_load_from_const = need_load_from_const | |||
self.implicit_gemm_mode = implicit_gemm_mode | |||
self.without_shared_load = without_shared_load | |||
# | |||
def accumulator_type(self): | |||
accum = self.tile_description.math_instruction.element_accumulator | |||
@@ -58,11 +59,15 @@ class Conv2dOperation: | |||
unity_kernel = '' | |||
if not self.need_load_from_const: | |||
unity_kernel = '_1x1' | |||
unity_kernel = '_1x1' | |||
return "%s%s%s%s%s_%s" % (ShortDataTypeNames[self.accumulator_type()], \ | |||
reorder_k = '' | |||
if self.without_shared_load: | |||
reorder_k = '_roc' | |||
return "%s%s%s%s%s%s_%s" % (ShortDataTypeNames[self.accumulator_type()], \ | |||
inst_shape, intermediate_type, ConvKindNames[self.conv_kind], unity_kernel, \ | |||
ShortEpilogueNames[self.epilogue_functor]) | |||
reorder_k, ShortEpilogueNames[self.epilogue_functor]) | |||
# | |||
def extended_name(self): | |||
@@ -177,7 +182,8 @@ using Convolution = | |||
${alignment_filter}, | |||
${nonuninity_kernel}, | |||
${math_operator}, | |||
${implicit_gemm_mode}>; | |||
${implicit_gemm_mode}, | |||
${without_shared_load}>; | |||
""" | |||
@@ -219,7 +225,8 @@ using Convolution = | |||
'alignment_filter': str(operation.flt.alignment), | |||
'nonuninity_kernel': str(operation.need_load_from_const).lower(), | |||
'math_operator': MathOperationTag[operation.tile_description.math_instruction.math_operation], | |||
'implicit_gemm_mode': ImplicitGemmModeTag[operation.implicit_gemm_mode] | |||
'implicit_gemm_mode': ImplicitGemmModeTag[operation.implicit_gemm_mode], | |||
'without_shared_load': str(operation.without_shared_load).lower() | |||
} | |||
return SubstituteTemplate(self.template, values) | |||
@@ -312,13 +319,13 @@ using Deconvolution = | |||
# | |||
def GenerateConv2d(conv_kind, tile_descriptions, src_layout, flt_layout, dst_layout, dst_type, min_cc, src_align = 32, flt_align = 32, dst_align = 128, \ | |||
skip_unity_kernel = False, implicit_gemm_mode = ImplicitGemmMode.GemmNt): | |||
skip_unity_kernel = False, implicit_gemm_mode = ImplicitGemmMode.GemmNT, without_shared_load = False): | |||
operations = [] | |||
element_epilogue = DataType.f32 | |||
if conv_kind == ConvKind.Fprop: | |||
if src_layout == LayoutType.TensorNHWC: | |||
swizzling_functor = SwizzlingFunctor.ConvFpropNHWC | |||
if implicit_gemm_mode == ImplicitGemmMode.GemmTN: | |||
swizzling_functor = SwizzlingFunctor.ConvFpropTrans | |||
else: | |||
swizzling_functor = SwizzlingFunctor.ConvFpropNCxHWx | |||
else: | |||
@@ -399,10 +406,10 @@ def GenerateConv2d(conv_kind, tile_descriptions, src_layout, flt_layout, dst_lay | |||
bias = TensorDescription(bias_type, dst_layout, max(1, int(32 / DataTypeSize[bias_type]))) | |||
dst = TensorDescription(dst_type, dst_layout, int(dst_align / DataTypeSize[dst_type])) | |||
new_operation = Conv2dOperation(conv_kind, ConvType.Convolution, min_cc, tile, src, flt, bias, dst, element_epilogue, epilogue, swizzling_functor, True, implicit_gemm_mode) | |||
new_operation = Conv2dOperation(conv_kind, ConvType.Convolution, min_cc, tile, src, flt, bias, dst, element_epilogue, epilogue, swizzling_functor, True, implicit_gemm_mode, without_shared_load) | |||
operations.append(new_operation) | |||
if not skip_unity_kernel: | |||
new_operation = Conv2dOperation(conv_kind, ConvType.Convolution, min_cc, tile, src, flt, bias, dst, element_epilogue, epilogue, swizzling_functor, False, implicit_gemm_mode) | |||
new_operation = Conv2dOperation(conv_kind, ConvType.Convolution, min_cc, tile, src, flt, bias, dst, element_epilogue, epilogue, swizzling_functor, False, implicit_gemm_mode, without_shared_load) | |||
operations.append(new_operation) | |||
return operations | |||
@@ -175,12 +175,10 @@ def GenerateConv2d_Simt(args): | |||
TileDescription([128, 128, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc), | |||
TileDescription([128, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), | |||
TileDescription([ 64, 128, 32], 2, [1, 4, 1], math_inst, min_cc, max_cc), | |||
TileDescription([ 64, 64, 32], 2, [1, 2, 1], math_inst, min_cc, max_cc), | |||
TileDescription([128, 32, 32], 2, [2, 1, 1], math_inst, min_cc, max_cc), | |||
TileDescription([ 32, 128, 32], 2, [1, 2, 1], math_inst, min_cc, max_cc), | |||
TileDescription([ 32, 64, 32], 2, [1, 1, 1], math_inst, min_cc, max_cc), | |||
TileDescription([ 64, 32, 32], 2, [1, 1, 1], math_inst, min_cc, max_cc), | |||
TileDescription([ 32, 32, 32], 2, [1, 1, 1], math_inst, min_cc, max_cc), | |||
TileDescription([ 16, 128, 16], 1, [1, 1, 1], math_inst, min_cc, max_cc), | |||
TileDescription([ 16, 64, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc), | |||
] | |||
@@ -223,28 +221,36 @@ def GenerateConv2d_TensorOp_8816(args): | |||
for dst_type, dst_layout in zip(dst_types, dst_layouts): | |||
if dst_layout == LayoutType.TensorNC32HW32: | |||
tile_descriptions = [ | |||
TileDescription([256, 128, 64], 2, [4, 2, 1], math_inst, min_cc, max_cc), | |||
TileDescription([128, 256, 64], 2, [2, 4, 1], math_inst, min_cc, max_cc), | |||
TileDescription([256, 128, 64], 2, [4, 2, 1], math_inst, min_cc, max_cc), | |||
TileDescription([128, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), | |||
TileDescription([ 64, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), | |||
TileDescription([128, 64, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), | |||
TileDescription([ 64, 64, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), | |||
TileDescription([ 32, 64, 64], 2, [1, 4, 1], math_inst, min_cc, max_cc), | |||
TileDescription([ 64, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), | |||
TileDescription([128, 64, 32], 1, [2, 2, 1], math_inst, min_cc, max_cc), | |||
TileDescription([128, 32, 32], 1, [2, 1, 1], math_inst, min_cc, max_cc), | |||
TileDescription([ 64, 128, 32], 1, [2, 2, 1], math_inst, min_cc, max_cc), | |||
TileDescription([ 32, 128, 32], 1, [1, 2, 1], math_inst, min_cc, max_cc), | |||
] | |||
operations += GenerateConv2d(ConvKind.Fprop, tile_descriptions, layout[0], layout[1], | |||
dst_layout, dst_type, min_cc, 128, 128, 64, | |||
False, ImplicitGemmMode.GemmTN, True) | |||
else: | |||
assert dst_layout == LayoutType.TensorNC4HW4 | |||
tile_descriptions = [ | |||
TileDescription([256, 128, 64], 2, [4, 2, 1], math_inst, min_cc, max_cc), | |||
TileDescription([128, 256, 64], 2, [2, 4, 1], math_inst, min_cc, max_cc), | |||
TileDescription([256, 128, 64], 2, [4, 2, 1], math_inst, min_cc, max_cc), | |||
TileDescription([128, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), | |||
TileDescription([ 64, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), | |||
TileDescription([128, 64, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), | |||
TileDescription([ 64, 64, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), | |||
TileDescription([ 32, 64, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), | |||
TileDescription([ 64, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), | |||
TileDescription([128, 64, 32], 1, [2, 2, 1], math_inst, min_cc, max_cc), | |||
TileDescription([128, 32, 32], 1, [2, 1, 1], math_inst, min_cc, max_cc), | |||
TileDescription([ 64, 128, 32], 1, [2, 2, 1], math_inst, min_cc, max_cc), | |||
TileDescription([ 32, 128, 32], 1, [1, 2, 1], math_inst, min_cc, max_cc), | |||
] | |||
operations += GenerateConv2d(ConvKind.Fprop, tile_descriptions, layout[0], layout[1], | |||
operations += GenerateConv2d(ConvKind.Fprop, tile_descriptions, layout[0], layout[1], | |||
dst_layout, dst_type, min_cc, 128, 128, 64, | |||
False) | |||
return operations | |||
def GenerateConv2d_TensorOp_8832(args): | |||
@@ -279,12 +285,14 @@ def GenerateConv2d_TensorOp_8832(args): | |||
for dst_layout in dst_layouts: | |||
dst_type = math_inst.element_b | |||
tile_descriptions = [ | |||
TileDescription([256, 128, 128], 2, [4, 2, 1], math_inst, min_cc, max_cc), | |||
TileDescription([128, 256, 128], 2, [2, 4, 1], math_inst, min_cc, max_cc), | |||
TileDescription([128, 128, 128], 2, [2, 2, 1], math_inst, min_cc, max_cc), | |||
TileDescription([128, 64, 128], 2, [2, 1, 1], math_inst, min_cc, max_cc), | |||
TileDescription([128, 64, 64], 1, [2, 1, 1], math_inst, min_cc, max_cc), | |||
] | |||
operations += GenerateConv2d(ConvKind.Fprop, tile_descriptions, layout[0], layout[1], | |||
dst_layout, dst_type, min_cc, 128, 128, 64, | |||
True) | |||
True, ImplicitGemmMode.GemmTN, True) | |||
layouts_nhwc = [ | |||
(LayoutType.TensorNHWC, LayoutType.TensorNC8HW8, 32), | |||
@@ -299,14 +307,21 @@ def GenerateConv2d_TensorOp_8832(args): | |||
for math_inst in math_instructions: | |||
for layout in layouts_nhwc: | |||
for dst_layout in dst_layouts_nhwc: | |||
dst_type = math_inst.element_b | |||
tile_descriptions = [ | |||
TileDescription([128, 32, 64], 2, [2, 1, 1], math_inst, min_cc, max_cc), | |||
TileDescription([128, 64, 64], 2, [2, 1, 1], math_inst, min_cc, max_cc), | |||
] | |||
operations += GenerateConv2d(ConvKind.Fprop, tile_descriptions, layout[0], layout[1], | |||
dst_layout, dst_type, min_cc, layout[2], layout[2], 32, | |||
False, ImplicitGemmMode.GemmTn) | |||
dst_type = math_inst.element_b | |||
tile_descriptions = [ | |||
TileDescription([128, 32, 64], 1, [2, 1, 1], math_inst, min_cc, max_cc), | |||
TileDescription([128, 64, 64], 1, [2, 1, 1], math_inst, min_cc, max_cc), | |||
] | |||
for tile in tile_descriptions: | |||
operations += GenerateConv2d(ConvKind.Fprop, [tile], layout[0], layout[1], | |||
dst_layout, dst_type, min_cc, layout[2], layout[2], 32, | |||
False, ImplicitGemmMode.GemmTN, False) | |||
if tile.threadblock_shape[1] == 32 or tile.threadblock_shape[1] == 64: | |||
dst_align = 32 if tile.threadblock_shape[1] == 32 else 64 | |||
operations += GenerateConv2d(ConvKind.Fprop, [tile], layout[0], layout[1], | |||
dst_layout, dst_type, min_cc, layout[2], layout[2], dst_align, | |||
False, ImplicitGemmMode.GemmTN, True) | |||
return operations | |||
def GenerateDeconv_Simt(args): | |||
@@ -649,3 +664,4 @@ if __name__ == "__main__": | |||
# | |||
################################################################################################### | |||
@@ -464,10 +464,10 @@ EpilogueFunctorTag = { | |||
ShortEpilogueNames = { | |||
EpilogueFunctor.BiasAddLinearCombinationHSwishClamp: 'hswish', | |||
EpilogueFunctor.BiasAddLinearCombinationReluClamp: 'relu', | |||
EpilogueFunctor.BiasAddLinearCombinationClamp: 'identity', | |||
EpilogueFunctor.BiasAddLinearCombinationClamp: 'id', | |||
EpilogueFunctor.BiasAddLinearCombinationHSwish: 'hswish', | |||
EpilogueFunctor.BiasAddLinearCombinationRelu: 'relu', | |||
EpilogueFunctor.BiasAddLinearCombination: 'identity', | |||
EpilogueFunctor.BiasAddLinearCombination: 'id', | |||
} | |||
@@ -482,7 +482,7 @@ class SwizzlingFunctor(enum.Enum): | |||
Identity4 = enum_auto() | |||
Identity8 = enum_auto() | |||
ConvFpropNCxHWx = enum_auto() | |||
ConvFpropNHWC = enum_auto() | |||
ConvFpropTrans = enum_auto() | |||
ConvDgradNCxHWx = enum_auto() | |||
# | |||
@@ -492,7 +492,7 @@ SwizzlingFunctorTag = { | |||
SwizzlingFunctor.Identity4: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>', | |||
SwizzlingFunctor.Identity8: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>', | |||
SwizzlingFunctor.ConvFpropNCxHWx: 'cutlass::conv::threadblock::ConvolutionFpropNCxHWxThreadblockSwizzle', | |||
SwizzlingFunctor.ConvFpropNHWC: 'cutlass::conv::threadblock::ConvolutionFpropNHWCThreadblockSwizzle', | |||
SwizzlingFunctor.ConvFpropTrans: 'cutlass::conv::threadblock::ConvolutionFpropTransThreadblockSwizzle', | |||
SwizzlingFunctor.ConvDgradNCxHWx: 'cutlass::conv::threadblock::ConvolutionDgradNCxHWxThreadblockSwizzle', | |||
} | |||
@@ -563,17 +563,17 @@ StrideSupportNames = { | |||
} | |||
class ImplicitGemmMode(enum.Enum): | |||
GemmNt = enum_auto() | |||
GemmTn = enum_auto() | |||
GemmNT = enum_auto() | |||
GemmTN = enum_auto() | |||
ImplicitGemmModeNames = { | |||
ImplicitGemmMode.GemmNt: 'gemm_nt', | |||
ImplicitGemmMode.GemmTn: 'gemm_tn', | |||
ImplicitGemmMode.GemmNT: 'gemm_nt', | |||
ImplicitGemmMode.GemmTN: 'gemm_tn', | |||
} | |||
ImplicitGemmModeTag = { | |||
ImplicitGemmMode.GemmNt: 'cutlass::conv::ImplicitGemmMode::GEMM_NT', | |||
ImplicitGemmMode.GemmTn: 'cutlass::conv::ImplicitGemmMode::GEMM_TN', | |||
ImplicitGemmMode.GemmNT: 'cutlass::conv::ImplicitGemmMode::GEMM_NT', | |||
ImplicitGemmMode.GemmTN: 'cutlass::conv::ImplicitGemmMode::GEMM_TN', | |||
} | |||
################################################################################################### | |||
@@ -164,415 +164,461 @@ cutlass_gen_list = [ | |||
"cutlass_simt_sgemv_batched_strided_1x32_32_tt_align1x4.cu", | |||
"cutlass_simt_sgemv_batched_strided_1x32_16_tt_align1x2.cu", | |||
"cutlass_simt_sgemv_batched_strided_1x32_8_tt_align1x1.cu", | |||
"cutlass_simt_s8_idgrad_identity_s8_64x128x32_64x32x32_2_nc4hw4_k4rsc4.cu", | |||
"cutlass_simt_s8_idgrad_identity_s8_32x128x32_32x64x32_2_nc4hw4_k4rsc4.cu", | |||
"cutlass_simt_s8_idgrad_identity_s8_16x128x16_16x64x16_2_nc4hw4_k4rsc4.cu", | |||
"cutlass_simt_s8_idgrad_identity_s8_16x128x16_16x128x16_1_nc4hw4_k4rsc4.cu", | |||
"cutlass_simt_s8_idgrad_identity_s8_16x64x8_16x64x8_2_nc4hw4_k4rsc4.cu", | |||
"cutlass_simt_s8_ifprop_identity_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
"cutlass_simt_s8_ifprop_1x1_identity_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
"cutlass_simt_s8_idgrad_id_s8_64x128x32_64x32x32_2_nc4hw4_k4rsc4.cu", | |||
"cutlass_simt_s8_idgrad_id_s8_32x128x32_32x64x32_2_nc4hw4_k4rsc4.cu", | |||
"cutlass_simt_s8_idgrad_id_s8_16x128x16_16x64x16_2_nc4hw4_k4rsc4.cu", | |||
"cutlass_simt_s8_idgrad_id_s8_16x128x16_16x128x16_1_nc4hw4_k4rsc4.cu", | |||
"cutlass_simt_s8_idgrad_id_s8_16x64x8_16x64x8_2_nc4hw4_k4rsc4.cu", | |||
"cutlass_simt_s8_ifprop_id_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
"cutlass_simt_s8_ifprop_1x1_id_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
"cutlass_simt_s8_ifprop_relu_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
"cutlass_simt_s8_ifprop_1x1_relu_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
"cutlass_simt_s8_ifprop_hswish_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
"cutlass_simt_s8_ifprop_1x1_hswish_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
"cutlass_simt_s8_ifprop_identity_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
"cutlass_simt_s8_ifprop_1x1_identity_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
"cutlass_simt_s8_ifprop_id_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
"cutlass_simt_s8_ifprop_1x1_id_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
"cutlass_simt_s8_ifprop_relu_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
"cutlass_simt_s8_ifprop_1x1_relu_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
"cutlass_simt_s8_ifprop_hswish_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
"cutlass_simt_s8_ifprop_1x1_hswish_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
"cutlass_simt_s8_ifprop_identity_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
"cutlass_simt_s8_ifprop_1x1_identity_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
"cutlass_simt_s8_ifprop_id_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
"cutlass_simt_s8_ifprop_1x1_id_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
"cutlass_simt_s8_ifprop_relu_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
"cutlass_simt_s8_ifprop_1x1_relu_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
"cutlass_simt_s8_ifprop_hswish_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
"cutlass_simt_s8_ifprop_1x1_hswish_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
"cutlass_simt_s8_ifprop_identity_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
"cutlass_simt_s8_ifprop_1x1_identity_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
"cutlass_simt_s8_ifprop_relu_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
"cutlass_simt_s8_ifprop_1x1_relu_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
"cutlass_simt_s8_ifprop_hswish_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
"cutlass_simt_s8_ifprop_1x1_hswish_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
"cutlass_simt_s8_ifprop_identity_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
"cutlass_simt_s8_ifprop_1x1_identity_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
"cutlass_simt_s8_ifprop_id_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
"cutlass_simt_s8_ifprop_1x1_id_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
"cutlass_simt_s8_ifprop_relu_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
"cutlass_simt_s8_ifprop_1x1_relu_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
"cutlass_simt_s8_ifprop_hswish_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
"cutlass_simt_s8_ifprop_1x1_hswish_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
"cutlass_simt_s8_ifprop_identity_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4.cu", | |||
"cutlass_simt_s8_ifprop_1x1_identity_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4.cu", | |||
"cutlass_simt_s8_ifprop_id_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4.cu", | |||
"cutlass_simt_s8_ifprop_1x1_id_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4.cu", | |||
"cutlass_simt_s8_ifprop_relu_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4.cu", | |||
"cutlass_simt_s8_ifprop_1x1_relu_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4.cu", | |||
"cutlass_simt_s8_ifprop_hswish_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4.cu", | |||
"cutlass_simt_s8_ifprop_1x1_hswish_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4.cu", | |||
"cutlass_simt_s8_ifprop_identity_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4.cu", | |||
"cutlass_simt_s8_ifprop_1x1_identity_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4.cu", | |||
"cutlass_simt_s8_ifprop_id_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4.cu", | |||
"cutlass_simt_s8_ifprop_1x1_id_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4.cu", | |||
"cutlass_simt_s8_ifprop_relu_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4.cu", | |||
"cutlass_simt_s8_ifprop_1x1_relu_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4.cu", | |||
"cutlass_simt_s8_ifprop_hswish_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4.cu", | |||
"cutlass_simt_s8_ifprop_1x1_hswish_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4.cu", | |||
"cutlass_simt_s8_ifprop_identity_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
"cutlass_simt_s8_ifprop_1x1_identity_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
"cutlass_simt_s8_ifprop_id_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
"cutlass_simt_s8_ifprop_1x1_id_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
"cutlass_simt_s8_ifprop_relu_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
"cutlass_simt_s8_ifprop_1x1_relu_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
"cutlass_simt_s8_ifprop_hswish_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
"cutlass_simt_s8_ifprop_1x1_hswish_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
"cutlass_simt_s8_ifprop_identity_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4.cu", | |||
"cutlass_simt_s8_ifprop_1x1_identity_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4.cu", | |||
"cutlass_simt_s8_ifprop_relu_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4.cu", | |||
"cutlass_simt_s8_ifprop_1x1_relu_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4.cu", | |||
"cutlass_simt_s8_ifprop_hswish_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4.cu", | |||
"cutlass_simt_s8_ifprop_1x1_hswish_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4.cu", | |||
"cutlass_simt_s8_ifprop_identity_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4.cu", | |||
"cutlass_simt_s8_ifprop_1x1_identity_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4.cu", | |||
"cutlass_simt_s8_ifprop_id_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4.cu", | |||
"cutlass_simt_s8_ifprop_1x1_id_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4.cu", | |||
"cutlass_simt_s8_ifprop_relu_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4.cu", | |||
"cutlass_simt_s8_ifprop_1x1_relu_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4.cu", | |||
"cutlass_simt_s8_ifprop_hswish_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4.cu", | |||
"cutlass_simt_s8_ifprop_1x1_hswish_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4.cu", | |||
"cutlass_simt_s8_ifprop_identity_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4.cu", | |||
"cutlass_simt_s8_ifprop_1x1_identity_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4.cu", | |||
"cutlass_simt_s8_ifprop_id_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4.cu", | |||
"cutlass_simt_s8_ifprop_1x1_id_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4.cu", | |||
"cutlass_simt_s8_ifprop_relu_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4.cu", | |||
"cutlass_simt_s8_ifprop_1x1_relu_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4.cu", | |||
"cutlass_simt_s8_ifprop_hswish_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4.cu", | |||
"cutlass_simt_s8_ifprop_1x1_hswish_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4.cu", | |||
"cutlass_simt_s8_ifprop_identity_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
"cutlass_simt_s8_ifprop_1x1_identity_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
"cutlass_simt_s8_ifprop_id_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
"cutlass_simt_s8_ifprop_1x1_id_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
"cutlass_simt_s8_ifprop_relu_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
"cutlass_simt_s8_ifprop_1x1_relu_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
"cutlass_simt_s8_ifprop_hswish_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
"cutlass_simt_s8_ifprop_1x1_hswish_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
"cutlass_simt_s8_ifprop_identity_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
"cutlass_simt_s8_ifprop_1x1_identity_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
"cutlass_simt_s8_ifprop_id_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
"cutlass_simt_s8_ifprop_1x1_id_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
"cutlass_simt_s8_ifprop_relu_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
"cutlass_simt_s8_ifprop_1x1_relu_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
"cutlass_simt_s8_ifprop_hswish_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
"cutlass_simt_s8_ifprop_1x1_hswish_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
"cutlass_simt_s8_ifprop_identity_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
"cutlass_simt_s8_ifprop_1x1_identity_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
"cutlass_simt_s8_ifprop_id_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
"cutlass_simt_s8_ifprop_1x1_id_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
"cutlass_simt_s8_ifprop_relu_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
"cutlass_simt_s8_ifprop_1x1_relu_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
"cutlass_simt_s8_ifprop_hswish_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
"cutlass_simt_s8_ifprop_1x1_hswish_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
"cutlass_simt_s8_ifprop_identity_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
"cutlass_simt_s8_ifprop_1x1_identity_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
"cutlass_simt_s8_ifprop_relu_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
"cutlass_simt_s8_ifprop_1x1_relu_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
"cutlass_simt_s8_ifprop_hswish_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
"cutlass_simt_s8_ifprop_1x1_hswish_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
"cutlass_simt_s8_ifprop_identity_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
"cutlass_simt_s8_ifprop_1x1_identity_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
"cutlass_simt_s8_ifprop_id_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
"cutlass_simt_s8_ifprop_1x1_id_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
"cutlass_simt_s8_ifprop_relu_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
"cutlass_simt_s8_ifprop_1x1_relu_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
"cutlass_simt_s8_ifprop_hswish_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
"cutlass_simt_s8_ifprop_1x1_hswish_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
"cutlass_simt_s8_ifprop_identity_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
"cutlass_simt_s8_ifprop_1x1_identity_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
"cutlass_simt_s8_ifprop_id_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
"cutlass_simt_s8_ifprop_1x1_id_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
"cutlass_simt_s8_ifprop_relu_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
"cutlass_simt_s8_ifprop_1x1_relu_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
"cutlass_simt_s8_ifprop_hswish_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
"cutlass_simt_s8_ifprop_1x1_hswish_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
"cutlass_simt_s8_ifprop_identity_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
"cutlass_simt_s8_ifprop_1x1_identity_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
"cutlass_simt_s8_ifprop_id_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
"cutlass_simt_s8_ifprop_1x1_id_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
"cutlass_simt_s8_ifprop_relu_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
"cutlass_simt_s8_ifprop_1x1_relu_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
"cutlass_simt_s8_ifprop_hswish_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
"cutlass_simt_s8_ifprop_1x1_hswish_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
"cutlass_simt_s8_ifprop_identity_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
"cutlass_simt_s8_ifprop_1x1_identity_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
"cutlass_simt_s8_ifprop_id_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
"cutlass_simt_s8_ifprop_1x1_id_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
"cutlass_simt_s8_ifprop_relu_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
"cutlass_simt_s8_ifprop_1x1_relu_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
"cutlass_simt_s8_ifprop_hswish_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
"cutlass_simt_s8_ifprop_1x1_hswish_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
"cutlass_simt_s8_ifprop_identity_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
"cutlass_simt_s8_ifprop_1x1_identity_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
"cutlass_simt_s8_ifprop_relu_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
"cutlass_simt_s8_ifprop_1x1_relu_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
"cutlass_simt_s8_ifprop_hswish_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
"cutlass_simt_s8_ifprop_1x1_hswish_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
"cutlass_simt_u4_ifprop_identity_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
"cutlass_simt_u4_ifprop_id_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
"cutlass_simt_u4_ifprop_relu_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
"cutlass_simt_u4_ifprop_hswish_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
"cutlass_simt_u4_ifprop_identity_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
"cutlass_simt_u4_ifprop_id_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
"cutlass_simt_u4_ifprop_relu_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
"cutlass_simt_u4_ifprop_hswish_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
"cutlass_simt_u4_ifprop_identity_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
"cutlass_simt_u4_ifprop_id_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
"cutlass_simt_u4_ifprop_relu_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
"cutlass_simt_u4_ifprop_hswish_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
"cutlass_simt_u4_ifprop_identity_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
"cutlass_simt_u4_ifprop_relu_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
"cutlass_simt_u4_ifprop_hswish_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
"cutlass_simt_u4_ifprop_identity_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
"cutlass_simt_u4_ifprop_id_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
"cutlass_simt_u4_ifprop_relu_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
"cutlass_simt_u4_ifprop_hswish_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
"cutlass_simt_u4_ifprop_identity_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
"cutlass_simt_u4_ifprop_id_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
"cutlass_simt_u4_ifprop_relu_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
"cutlass_simt_u4_ifprop_hswish_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
"cutlass_simt_u4_ifprop_identity_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
"cutlass_simt_u4_ifprop_id_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
"cutlass_simt_u4_ifprop_relu_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
"cutlass_simt_u4_ifprop_hswish_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
"cutlass_simt_u4_ifprop_identity_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
"cutlass_simt_u4_ifprop_id_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
"cutlass_simt_u4_ifprop_relu_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
"cutlass_simt_u4_ifprop_hswish_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
"cutlass_simt_u4_ifprop_identity_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
"cutlass_simt_u4_ifprop_relu_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
"cutlass_simt_u4_ifprop_hswish_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
"cutlass_simt_u4_ifprop_identity_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4_nhwc.cu", | |||
"cutlass_simt_u4_ifprop_id_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4_nhwc.cu", | |||
"cutlass_simt_u4_ifprop_relu_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4_nhwc.cu", | |||
"cutlass_simt_u4_ifprop_hswish_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4_nhwc.cu", | |||
"cutlass_simt_u4_ifprop_identity_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4_nhwc.cu", | |||
"cutlass_simt_u4_ifprop_id_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4_nhwc.cu", | |||
"cutlass_simt_u4_ifprop_relu_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4_nhwc.cu", | |||
"cutlass_simt_u4_ifprop_hswish_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4_nhwc.cu", | |||
"cutlass_simt_s4_ifprop_identity_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
"cutlass_simt_s4_ifprop_id_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
"cutlass_simt_s4_ifprop_relu_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
"cutlass_simt_s4_ifprop_hswish_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
"cutlass_simt_s4_ifprop_identity_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
"cutlass_simt_s4_ifprop_id_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
"cutlass_simt_s4_ifprop_relu_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
"cutlass_simt_s4_ifprop_hswish_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
"cutlass_simt_s4_ifprop_identity_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
"cutlass_simt_s4_ifprop_id_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
"cutlass_simt_s4_ifprop_relu_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
"cutlass_simt_s4_ifprop_hswish_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
"cutlass_simt_s4_ifprop_identity_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
"cutlass_simt_s4_ifprop_relu_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
"cutlass_simt_s4_ifprop_hswish_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
"cutlass_simt_s4_ifprop_identity_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
"cutlass_simt_s4_ifprop_id_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
"cutlass_simt_s4_ifprop_relu_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
"cutlass_simt_s4_ifprop_hswish_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
"cutlass_simt_s4_ifprop_identity_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
"cutlass_simt_s4_ifprop_id_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
"cutlass_simt_s4_ifprop_relu_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
"cutlass_simt_s4_ifprop_hswish_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
"cutlass_simt_s4_ifprop_identity_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
"cutlass_simt_s4_ifprop_id_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
"cutlass_simt_s4_ifprop_relu_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
"cutlass_simt_s4_ifprop_hswish_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
"cutlass_simt_s4_ifprop_identity_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
"cutlass_simt_s4_ifprop_id_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
"cutlass_simt_s4_ifprop_relu_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
"cutlass_simt_s4_ifprop_hswish_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
"cutlass_simt_s4_ifprop_identity_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
"cutlass_simt_s4_ifprop_relu_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
"cutlass_simt_s4_ifprop_hswish_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
"cutlass_simt_s4_ifprop_identity_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4_nhwc.cu", | |||
"cutlass_simt_s4_ifprop_id_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4_nhwc.cu", | |||
"cutlass_simt_s4_ifprop_relu_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4_nhwc.cu", | |||
"cutlass_simt_s4_ifprop_hswish_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4_nhwc.cu", | |||
"cutlass_simt_s4_ifprop_identity_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4_nhwc.cu", | |||
"cutlass_simt_s4_ifprop_id_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4_nhwc.cu", | |||
"cutlass_simt_s4_ifprop_relu_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4_nhwc.cu", | |||
"cutlass_simt_s4_ifprop_hswish_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4_nhwc.cu", | |||
"cutlass_simt_f32_ifprop_identity_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
"cutlass_simt_f32_ifprop_1x1_identity_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
"cutlass_simt_f32_ifprop_id_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
"cutlass_simt_f32_ifprop_1x1_id_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
"cutlass_simt_f32_ifprop_relu_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
"cutlass_simt_f32_ifprop_1x1_relu_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
"cutlass_simt_f32_ifprop_hswish_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
"cutlass_simt_f32_ifprop_1x1_hswish_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
"cutlass_simt_f32_ifprop_identity_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
"cutlass_simt_f32_ifprop_1x1_identity_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
"cutlass_simt_f32_ifprop_id_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
"cutlass_simt_f32_ifprop_1x1_id_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
"cutlass_simt_f32_ifprop_relu_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
"cutlass_simt_f32_ifprop_1x1_relu_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
"cutlass_simt_f32_ifprop_hswish_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
"cutlass_simt_f32_ifprop_1x1_hswish_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
"cutlass_simt_f32_ifprop_identity_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
"cutlass_simt_f32_ifprop_1x1_identity_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
"cutlass_simt_f32_ifprop_id_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
"cutlass_simt_f32_ifprop_1x1_id_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
"cutlass_simt_f32_ifprop_relu_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
"cutlass_simt_f32_ifprop_1x1_relu_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
"cutlass_simt_f32_ifprop_hswish_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
"cutlass_simt_f32_ifprop_1x1_hswish_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
"cutlass_simt_f32_ifprop_identity_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
"cutlass_simt_f32_ifprop_1x1_identity_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
"cutlass_simt_f32_ifprop_relu_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
"cutlass_simt_f32_ifprop_1x1_relu_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
"cutlass_simt_f32_ifprop_hswish_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
"cutlass_simt_f32_ifprop_1x1_hswish_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
"cutlass_simt_f32_ifprop_identity_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
"cutlass_simt_f32_ifprop_1x1_identity_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
"cutlass_simt_f32_ifprop_id_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
"cutlass_simt_f32_ifprop_1x1_id_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
"cutlass_simt_f32_ifprop_relu_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
"cutlass_simt_f32_ifprop_1x1_relu_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
"cutlass_simt_f32_ifprop_hswish_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
"cutlass_simt_f32_ifprop_1x1_hswish_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
"cutlass_simt_f32_ifprop_identity_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
"cutlass_simt_f32_ifprop_1x1_identity_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
"cutlass_simt_f32_ifprop_id_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
"cutlass_simt_f32_ifprop_1x1_id_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
"cutlass_simt_f32_ifprop_relu_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
"cutlass_simt_f32_ifprop_1x1_relu_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
"cutlass_simt_f32_ifprop_hswish_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
"cutlass_simt_f32_ifprop_1x1_hswish_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
"cutlass_simt_f32_ifprop_identity_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
"cutlass_simt_f32_ifprop_1x1_identity_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
"cutlass_simt_f32_ifprop_id_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
"cutlass_simt_f32_ifprop_1x1_id_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
"cutlass_simt_f32_ifprop_relu_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
"cutlass_simt_f32_ifprop_1x1_relu_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
"cutlass_simt_f32_ifprop_hswish_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
"cutlass_simt_f32_ifprop_1x1_hswish_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
"cutlass_simt_f32_ifprop_identity_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
"cutlass_simt_f32_ifprop_1x1_identity_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
"cutlass_simt_f32_ifprop_id_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
"cutlass_simt_f32_ifprop_1x1_id_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
"cutlass_simt_f32_ifprop_relu_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
"cutlass_simt_f32_ifprop_1x1_relu_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
"cutlass_simt_f32_ifprop_hswish_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
"cutlass_simt_f32_ifprop_1x1_hswish_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
"cutlass_simt_f32_ifprop_identity_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
"cutlass_simt_f32_ifprop_1x1_identity_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
"cutlass_simt_f32_ifprop_relu_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
"cutlass_simt_f32_ifprop_1x1_relu_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
"cutlass_simt_f32_ifprop_hswish_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
"cutlass_simt_f32_ifprop_1x1_hswish_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
"cutlass_simt_f32_ifprop_identity_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4_nchw.cu", | |||
"cutlass_simt_f32_ifprop_1x1_identity_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4_nchw.cu", | |||
"cutlass_simt_f32_ifprop_id_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4_nchw.cu", | |||
"cutlass_simt_f32_ifprop_1x1_id_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4_nchw.cu", | |||
"cutlass_simt_f32_ifprop_relu_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4_nchw.cu", | |||
"cutlass_simt_f32_ifprop_1x1_relu_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4_nchw.cu", | |||
"cutlass_simt_f32_ifprop_hswish_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4_nchw.cu", | |||
"cutlass_simt_f32_ifprop_1x1_hswish_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4_nchw.cu", | |||
"cutlass_simt_f32_ifprop_identity_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4_nchw.cu", | |||
"cutlass_simt_f32_ifprop_1x1_identity_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4_nchw.cu", | |||
"cutlass_simt_f32_ifprop_id_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4_nchw.cu", | |||
"cutlass_simt_f32_ifprop_1x1_id_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4_nchw.cu", | |||
"cutlass_simt_f32_ifprop_relu_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4_nchw.cu", | |||
"cutlass_simt_f32_ifprop_1x1_relu_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4_nchw.cu", | |||
"cutlass_simt_f32_ifprop_hswish_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4_nchw.cu", | |||
"cutlass_simt_f32_ifprop_1x1_hswish_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4_nchw.cu", | |||
"cutlass_tensorop_s8_i8816fprop_identity_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_1x1_identity_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_relu_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_hswish_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_identity_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_1x1_identity_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_relu_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_hswish_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_identity_s8_128x128x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_1x1_identity_s8_128x128x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_relu_s8_128x128x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_128x128x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_hswish_s8_128x128x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_128x128x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_identity_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_1x1_identity_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_relu_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_hswish_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_identity_s8_128x64x64_64x32x64_2_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_1x1_identity_s8_128x64x64_64x32x64_2_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_relu_s8_128x64x64_64x32x64_2_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_128x64x64_64x32x64_2_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_hswish_s8_128x64x64_64x32x64_2_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_128x64x64_64x32x64_2_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_identity_s8_64x64x64_32x32x64_2_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_1x1_identity_s8_64x64x64_32x32x64_2_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_relu_s8_64x64x64_32x32x64_2_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_64x64x64_32x32x64_2_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_hswish_s8_64x64x64_32x32x64_2_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_64x64x64_32x32x64_2_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_identity_s8_32x64x64_32x16x64_2_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_1x1_identity_s8_32x64x64_32x16x64_2_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_relu_s8_32x64x64_32x16x64_2_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_32x64x64_32x16x64_2_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_hswish_s8_32x64x64_32x16x64_2_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_32x64x64_32x16x64_2_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_identity_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
"cutlass_tensorop_s8_i8816fprop_1x1_identity_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
"cutlass_tensorop_s8_i8816fprop_relu_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
"cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
"cutlass_tensorop_s8_i8816fprop_hswish_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
"cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
"cutlass_tensorop_s8_i8816fprop_identity_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
"cutlass_tensorop_s8_i8816fprop_1x1_identity_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
"cutlass_tensorop_s8_i8816fprop_roc_id_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_1x1_roc_id_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_roc_relu_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_1x1_roc_relu_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_roc_hswish_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_1x1_roc_hswish_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_roc_id_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_1x1_roc_id_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_roc_relu_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_1x1_roc_relu_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_roc_hswish_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_1x1_roc_hswish_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_roc_id_s8_128x128x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_1x1_roc_id_s8_128x128x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_roc_relu_s8_128x128x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_1x1_roc_relu_s8_128x128x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_roc_hswish_s8_128x128x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_1x1_roc_hswish_s8_128x128x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_roc_id_s8_128x64x64_64x32x64_2_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_1x1_roc_id_s8_128x64x64_64x32x64_2_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_roc_relu_s8_128x64x64_64x32x64_2_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_1x1_roc_relu_s8_128x64x64_64x32x64_2_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_roc_hswish_s8_128x64x64_64x32x64_2_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_1x1_roc_hswish_s8_128x64x64_64x32x64_2_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_roc_id_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_1x1_roc_id_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_roc_relu_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_1x1_roc_relu_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_roc_hswish_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_1x1_roc_hswish_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_roc_id_s8_128x64x32_64x32x32_1_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_1x1_roc_id_s8_128x64x32_64x32x32_1_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_roc_relu_s8_128x64x32_64x32x32_1_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_1x1_roc_relu_s8_128x64x32_64x32x32_1_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_roc_hswish_s8_128x64x32_64x32x32_1_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_1x1_roc_hswish_s8_128x64x32_64x32x32_1_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_roc_id_s8_128x32x32_64x32x32_1_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_1x1_roc_id_s8_128x32x32_64x32x32_1_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_roc_relu_s8_128x32x32_64x32x32_1_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_1x1_roc_relu_s8_128x32x32_64x32x32_1_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_roc_hswish_s8_128x32x32_64x32x32_1_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_1x1_roc_hswish_s8_128x32x32_64x32x32_1_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_roc_id_s8_64x128x32_32x64x32_1_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_1x1_roc_id_s8_64x128x32_32x64x32_1_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_roc_relu_s8_64x128x32_32x64x32_1_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_1x1_roc_relu_s8_64x128x32_32x64x32_1_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_roc_hswish_s8_64x128x32_32x64x32_1_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_1x1_roc_hswish_s8_64x128x32_32x64x32_1_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_roc_id_s8_32x128x32_32x64x32_1_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_1x1_roc_id_s8_32x128x32_32x64x32_1_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_roc_relu_s8_32x128x32_32x64x32_1_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_1x1_roc_relu_s8_32x128x32_32x64x32_1_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_roc_hswish_s8_32x128x32_32x64x32_1_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_1x1_roc_hswish_s8_32x128x32_32x64x32_1_nc32hw32_c32rsk32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_id_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
"cutlass_tensorop_s8_i8816fprop_1x1_id_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
"cutlass_tensorop_s8_i8816fprop_relu_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
"cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
"cutlass_tensorop_s8_i8816fprop_hswish_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
"cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
"cutlass_tensorop_s8_i8816fprop_identity_s8_128x128x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
"cutlass_tensorop_s8_i8816fprop_1x1_identity_s8_128x128x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
"cutlass_tensorop_s8_i8816fprop_id_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
"cutlass_tensorop_s8_i8816fprop_1x1_id_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
"cutlass_tensorop_s8_i8816fprop_relu_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
"cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
"cutlass_tensorop_s8_i8816fprop_hswish_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
"cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
"cutlass_tensorop_s8_i8816fprop_id_s8_128x128x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
"cutlass_tensorop_s8_i8816fprop_1x1_id_s8_128x128x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
"cutlass_tensorop_s8_i8816fprop_relu_s8_128x128x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
"cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_128x128x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
"cutlass_tensorop_s8_i8816fprop_hswish_s8_128x128x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
"cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_128x128x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
"cutlass_tensorop_s8_i8816fprop_identity_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
"cutlass_tensorop_s8_i8816fprop_1x1_identity_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
"cutlass_tensorop_s8_i8816fprop_relu_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
"cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
"cutlass_tensorop_s8_i8816fprop_hswish_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
"cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
"cutlass_tensorop_s8_i8816fprop_identity_s8_128x64x64_64x32x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
"cutlass_tensorop_s8_i8816fprop_1x1_identity_s8_128x64x64_64x32x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
"cutlass_tensorop_s8_i8816fprop_id_s8_128x64x64_64x32x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
"cutlass_tensorop_s8_i8816fprop_1x1_id_s8_128x64x64_64x32x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
"cutlass_tensorop_s8_i8816fprop_relu_s8_128x64x64_64x32x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
"cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_128x64x64_64x32x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
"cutlass_tensorop_s8_i8816fprop_hswish_s8_128x64x64_64x32x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
"cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_128x64x64_64x32x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
"cutlass_tensorop_s8_i8816fprop_identity_s8_64x64x64_32x32x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
"cutlass_tensorop_s8_i8816fprop_1x1_identity_s8_64x64x64_32x32x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
"cutlass_tensorop_s8_i8816fprop_relu_s8_64x64x64_32x32x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
"cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_64x64x64_32x32x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
"cutlass_tensorop_s8_i8816fprop_hswish_s8_64x64x64_32x32x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
"cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_64x64x64_32x32x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
"cutlass_tensorop_s8_i8816fprop_identity_s8_32x64x64_16x32x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
"cutlass_tensorop_s8_i8816fprop_1x1_identity_s8_32x64x64_16x32x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
"cutlass_tensorop_s8_i8816fprop_relu_s8_32x64x64_16x32x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
"cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_32x64x64_16x32x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
"cutlass_tensorop_s8_i8816fprop_hswish_s8_32x64x64_16x32x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
"cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_32x64x64_16x32x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
"cutlass_tensorop_s4_i8832fprop_identity_s4_256x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||
"cutlass_tensorop_s4_i8832fprop_relu_s4_256x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||
"cutlass_tensorop_s4_i8832fprop_hswish_s4_256x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||
"cutlass_tensorop_s4_i8832fprop_identity_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||
"cutlass_tensorop_s4_i8832fprop_relu_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||
"cutlass_tensorop_s4_i8832fprop_hswish_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||
"cutlass_tensorop_u4_i8832fprop_identity_u4_s4_256x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||
"cutlass_tensorop_u4_i8832fprop_relu_u4_s4_256x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||
"cutlass_tensorop_u4_i8832fprop_identity_u4_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||
"cutlass_tensorop_u4_i8832fprop_relu_u4_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||
"cutlass_tensorop_s4_i8832fprop_identity_s4_128x32x64_64x32x64_2_nhwc_nc8hw8.cu", | |||
"cutlass_tensorop_s4_i8832fprop_1x1_identity_s4_128x32x64_64x32x64_2_nhwc_nc8hw8.cu", | |||
"cutlass_tensorop_s4_i8832fprop_relu_s4_128x32x64_64x32x64_2_nhwc_nc8hw8.cu", | |||
"cutlass_tensorop_s4_i8832fprop_1x1_relu_s4_128x32x64_64x32x64_2_nhwc_nc8hw8.cu", | |||
"cutlass_tensorop_s4_i8832fprop_hswish_s4_128x32x64_64x32x64_2_nhwc_nc8hw8.cu", | |||
"cutlass_tensorop_s4_i8832fprop_1x1_hswish_s4_128x32x64_64x32x64_2_nhwc_nc8hw8.cu", | |||
"cutlass_tensorop_s4_i8832fprop_identity_s4_128x64x64_64x64x64_2_nhwc_nc8hw8.cu", | |||
"cutlass_tensorop_s4_i8832fprop_1x1_identity_s4_128x64x64_64x64x64_2_nhwc_nc8hw8.cu", | |||
"cutlass_tensorop_s4_i8832fprop_relu_s4_128x64x64_64x64x64_2_nhwc_nc8hw8.cu", | |||
"cutlass_tensorop_s4_i8832fprop_1x1_relu_s4_128x64x64_64x64x64_2_nhwc_nc8hw8.cu", | |||
"cutlass_tensorop_s4_i8832fprop_hswish_s4_128x64x64_64x64x64_2_nhwc_nc8hw8.cu", | |||
"cutlass_tensorop_s4_i8832fprop_1x1_hswish_s4_128x64x64_64x64x64_2_nhwc_nc8hw8.cu", | |||
"cutlass_tensorop_s4_i8832fprop_identity_s4_128x32x64_64x32x64_2_nhwc_nc16hw16.cu", | |||
"cutlass_tensorop_s4_i8832fprop_1x1_identity_s4_128x32x64_64x32x64_2_nhwc_nc16hw16.cu", | |||
"cutlass_tensorop_s4_i8832fprop_relu_s4_128x32x64_64x32x64_2_nhwc_nc16hw16.cu", | |||
"cutlass_tensorop_s4_i8832fprop_1x1_relu_s4_128x32x64_64x32x64_2_nhwc_nc16hw16.cu", | |||
"cutlass_tensorop_s4_i8832fprop_hswish_s4_128x32x64_64x32x64_2_nhwc_nc16hw16.cu", | |||
"cutlass_tensorop_s4_i8832fprop_1x1_hswish_s4_128x32x64_64x32x64_2_nhwc_nc16hw16.cu", | |||
"cutlass_tensorop_s4_i8832fprop_identity_s4_128x64x64_64x64x64_2_nhwc_nc16hw16.cu", | |||
"cutlass_tensorop_s4_i8832fprop_1x1_identity_s4_128x64x64_64x64x64_2_nhwc_nc16hw16.cu", | |||
"cutlass_tensorop_s4_i8832fprop_relu_s4_128x64x64_64x64x64_2_nhwc_nc16hw16.cu", | |||
"cutlass_tensorop_s4_i8832fprop_1x1_relu_s4_128x64x64_64x64x64_2_nhwc_nc16hw16.cu", | |||
"cutlass_tensorop_s4_i8832fprop_hswish_s4_128x64x64_64x64x64_2_nhwc_nc16hw16.cu", | |||
"cutlass_tensorop_s4_i8832fprop_1x1_hswish_s4_128x64x64_64x64x64_2_nhwc_nc16hw16.cu", | |||
"cutlass_tensorop_s4_i8832fprop_identity_s4_128x32x64_64x32x64_2_nhwc_nc32hw32.cu", | |||
"cutlass_tensorop_s4_i8832fprop_1x1_identity_s4_128x32x64_64x32x64_2_nhwc_nc32hw32.cu", | |||
"cutlass_tensorop_s4_i8832fprop_relu_s4_128x32x64_64x32x64_2_nhwc_nc32hw32.cu", | |||
"cutlass_tensorop_s4_i8832fprop_1x1_relu_s4_128x32x64_64x32x64_2_nhwc_nc32hw32.cu", | |||
"cutlass_tensorop_s4_i8832fprop_hswish_s4_128x32x64_64x32x64_2_nhwc_nc32hw32.cu", | |||
"cutlass_tensorop_s4_i8832fprop_1x1_hswish_s4_128x32x64_64x32x64_2_nhwc_nc32hw32.cu", | |||
"cutlass_tensorop_s4_i8832fprop_identity_s4_128x64x64_64x64x64_2_nhwc_nc32hw32.cu", | |||
"cutlass_tensorop_s4_i8832fprop_1x1_identity_s4_128x64x64_64x64x64_2_nhwc_nc32hw32.cu", | |||
"cutlass_tensorop_s4_i8832fprop_relu_s4_128x64x64_64x64x64_2_nhwc_nc32hw32.cu", | |||
"cutlass_tensorop_s4_i8832fprop_1x1_relu_s4_128x64x64_64x64x64_2_nhwc_nc32hw32.cu", | |||
"cutlass_tensorop_s4_i8832fprop_hswish_s4_128x64x64_64x64x64_2_nhwc_nc32hw32.cu", | |||
"cutlass_tensorop_s4_i8832fprop_1x1_hswish_s4_128x64x64_64x64x64_2_nhwc_nc32hw32.cu", | |||
"cutlass_tensorop_u4_i8832fprop_identity_u4_s4_128x32x64_64x32x64_2_nhwc_nc8hw8.cu", | |||
"cutlass_tensorop_u4_i8832fprop_1x1_identity_u4_s4_128x32x64_64x32x64_2_nhwc_nc8hw8.cu", | |||
"cutlass_tensorop_u4_i8832fprop_relu_u4_s4_128x32x64_64x32x64_2_nhwc_nc8hw8.cu", | |||
"cutlass_tensorop_u4_i8832fprop_1x1_relu_u4_s4_128x32x64_64x32x64_2_nhwc_nc8hw8.cu", | |||
"cutlass_tensorop_u4_i8832fprop_identity_u4_s4_128x64x64_64x64x64_2_nhwc_nc8hw8.cu", | |||
"cutlass_tensorop_u4_i8832fprop_1x1_identity_u4_s4_128x64x64_64x64x64_2_nhwc_nc8hw8.cu", | |||
"cutlass_tensorop_u4_i8832fprop_relu_u4_s4_128x64x64_64x64x64_2_nhwc_nc8hw8.cu", | |||
"cutlass_tensorop_u4_i8832fprop_1x1_relu_u4_s4_128x64x64_64x64x64_2_nhwc_nc8hw8.cu", | |||
"cutlass_tensorop_u4_i8832fprop_identity_u4_s4_128x32x64_64x32x64_2_nhwc_nc16hw16.cu", | |||
"cutlass_tensorop_u4_i8832fprop_1x1_identity_u4_s4_128x32x64_64x32x64_2_nhwc_nc16hw16.cu", | |||
"cutlass_tensorop_u4_i8832fprop_relu_u4_s4_128x32x64_64x32x64_2_nhwc_nc16hw16.cu", | |||
"cutlass_tensorop_u4_i8832fprop_1x1_relu_u4_s4_128x32x64_64x32x64_2_nhwc_nc16hw16.cu", | |||
"cutlass_tensorop_u4_i8832fprop_identity_u4_s4_128x64x64_64x64x64_2_nhwc_nc16hw16.cu", | |||
"cutlass_tensorop_u4_i8832fprop_1x1_identity_u4_s4_128x64x64_64x64x64_2_nhwc_nc16hw16.cu", | |||
"cutlass_tensorop_u4_i8832fprop_relu_u4_s4_128x64x64_64x64x64_2_nhwc_nc16hw16.cu", | |||
"cutlass_tensorop_u4_i8832fprop_1x1_relu_u4_s4_128x64x64_64x64x64_2_nhwc_nc16hw16.cu", | |||
"cutlass_tensorop_u4_i8832fprop_identity_u4_s4_128x32x64_64x32x64_2_nhwc_nc32hw32.cu", | |||
"cutlass_tensorop_u4_i8832fprop_1x1_identity_u4_s4_128x32x64_64x32x64_2_nhwc_nc32hw32.cu", | |||
"cutlass_tensorop_u4_i8832fprop_relu_u4_s4_128x32x64_64x32x64_2_nhwc_nc32hw32.cu", | |||
"cutlass_tensorop_u4_i8832fprop_1x1_relu_u4_s4_128x32x64_64x32x64_2_nhwc_nc32hw32.cu", | |||
"cutlass_tensorop_u4_i8832fprop_identity_u4_s4_128x64x64_64x64x64_2_nhwc_nc32hw32.cu", | |||
"cutlass_tensorop_u4_i8832fprop_1x1_identity_u4_s4_128x64x64_64x64x64_2_nhwc_nc32hw32.cu", | |||
"cutlass_tensorop_u4_i8832fprop_relu_u4_s4_128x64x64_64x64x64_2_nhwc_nc32hw32.cu", | |||
"cutlass_tensorop_u4_i8832fprop_1x1_relu_u4_s4_128x64x64_64x64x64_2_nhwc_nc32hw32.cu", | |||
"cutlass_tensorop_s8_i8816fprop_id_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
"cutlass_tensorop_s8_i8816fprop_1x1_id_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
"cutlass_tensorop_s8_i8816fprop_relu_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
"cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
"cutlass_tensorop_s8_i8816fprop_hswish_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
"cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
"cutlass_tensorop_s8_i8816fprop_id_s8_128x64x32_64x32x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||
"cutlass_tensorop_s8_i8816fprop_1x1_id_s8_128x64x32_64x32x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||
"cutlass_tensorop_s8_i8816fprop_relu_s8_128x64x32_64x32x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||
"cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_128x64x32_64x32x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||
"cutlass_tensorop_s8_i8816fprop_hswish_s8_128x64x32_64x32x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||
"cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_128x64x32_64x32x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||
"cutlass_tensorop_s8_i8816fprop_id_s8_128x32x32_64x32x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||
"cutlass_tensorop_s8_i8816fprop_1x1_id_s8_128x32x32_64x32x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||
"cutlass_tensorop_s8_i8816fprop_relu_s8_128x32x32_64x32x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||
"cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_128x32x32_64x32x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||
"cutlass_tensorop_s8_i8816fprop_hswish_s8_128x32x32_64x32x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||
"cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_128x32x32_64x32x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||
"cutlass_tensorop_s8_i8816fprop_id_s8_64x128x32_32x64x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||
"cutlass_tensorop_s8_i8816fprop_1x1_id_s8_64x128x32_32x64x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||
"cutlass_tensorop_s8_i8816fprop_relu_s8_64x128x32_32x64x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||
"cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_64x128x32_32x64x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||
"cutlass_tensorop_s8_i8816fprop_hswish_s8_64x128x32_32x64x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||
"cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_64x128x32_32x64x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||
"cutlass_tensorop_s8_i8816fprop_id_s8_32x128x32_32x64x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||
"cutlass_tensorop_s8_i8816fprop_1x1_id_s8_32x128x32_32x64x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||
"cutlass_tensorop_s8_i8816fprop_relu_s8_32x128x32_32x64x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||
"cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_32x128x32_32x64x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||
"cutlass_tensorop_s8_i8816fprop_hswish_s8_32x128x32_32x64x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||
"cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_32x128x32_32x64x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||
"cutlass_tensorop_s4_i8832fprop_roc_id_s4_128x256x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||
"cutlass_tensorop_s4_i8832fprop_roc_relu_s4_128x256x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||
"cutlass_tensorop_s4_i8832fprop_roc_hswish_s4_128x256x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||
"cutlass_tensorop_s4_i8832fprop_roc_id_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||
"cutlass_tensorop_s4_i8832fprop_roc_relu_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||
"cutlass_tensorop_s4_i8832fprop_roc_hswish_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||
"cutlass_tensorop_s4_i8832fprop_roc_id_s4_128x64x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||
"cutlass_tensorop_s4_i8832fprop_roc_relu_s4_128x64x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||
"cutlass_tensorop_s4_i8832fprop_roc_hswish_s4_128x64x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||
"cutlass_tensorop_s4_i8832fprop_roc_id_s4_128x64x64_64x64x64_1_nc64hw64_c64rsk64.cu", | |||
"cutlass_tensorop_s4_i8832fprop_roc_relu_s4_128x64x64_64x64x64_1_nc64hw64_c64rsk64.cu", | |||
"cutlass_tensorop_s4_i8832fprop_roc_hswish_s4_128x64x64_64x64x64_1_nc64hw64_c64rsk64.cu", | |||
"cutlass_tensorop_u4_i8832fprop_roc_id_u4_s4_128x256x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||
"cutlass_tensorop_u4_i8832fprop_roc_relu_u4_s4_128x256x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||
"cutlass_tensorop_u4_i8832fprop_roc_id_u4_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||
"cutlass_tensorop_u4_i8832fprop_roc_relu_u4_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||
"cutlass_tensorop_u4_i8832fprop_roc_id_u4_s4_128x64x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||
"cutlass_tensorop_u4_i8832fprop_roc_relu_u4_s4_128x64x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||
"cutlass_tensorop_u4_i8832fprop_roc_id_u4_s4_128x64x64_64x64x64_1_nc64hw64_c64rsk64.cu", | |||
"cutlass_tensorop_u4_i8832fprop_roc_relu_u4_s4_128x64x64_64x64x64_1_nc64hw64_c64rsk64.cu", | |||
"cutlass_tensorop_s4_i8832fprop_id_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", | |||
"cutlass_tensorop_s4_i8832fprop_1x1_id_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", | |||
"cutlass_tensorop_s4_i8832fprop_relu_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", | |||
"cutlass_tensorop_s4_i8832fprop_1x1_relu_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", | |||
"cutlass_tensorop_s4_i8832fprop_hswish_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", | |||
"cutlass_tensorop_s4_i8832fprop_1x1_hswish_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", | |||
"cutlass_tensorop_s4_i8832fprop_roc_id_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", | |||
"cutlass_tensorop_s4_i8832fprop_1x1_roc_id_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", | |||
"cutlass_tensorop_s4_i8832fprop_roc_relu_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", | |||
"cutlass_tensorop_s4_i8832fprop_1x1_roc_relu_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", | |||
"cutlass_tensorop_s4_i8832fprop_roc_hswish_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", | |||
"cutlass_tensorop_s4_i8832fprop_1x1_roc_hswish_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", | |||
"cutlass_tensorop_s4_i8832fprop_id_s4_128x64x64_64x64x64_1_nhwc_nc8hw8.cu", | |||
"cutlass_tensorop_s4_i8832fprop_1x1_id_s4_128x64x64_64x64x64_1_nhwc_nc8hw8.cu", | |||
"cutlass_tensorop_s4_i8832fprop_relu_s4_128x64x64_64x64x64_1_nhwc_nc8hw8.cu", | |||
"cutlass_tensorop_s4_i8832fprop_1x1_relu_s4_128x64x64_64x64x64_1_nhwc_nc8hw8.cu", | |||
"cutlass_tensorop_s4_i8832fprop_hswish_s4_128x64x64_64x64x64_1_nhwc_nc8hw8.cu", | |||
"cutlass_tensorop_s4_i8832fprop_1x1_hswish_s4_128x64x64_64x64x64_1_nhwc_nc8hw8.cu", | |||
"cutlass_tensorop_s4_i8832fprop_roc_id_s4_128x64x64_64x64x64_1_nhwc_nc8hw8.cu", | |||
"cutlass_tensorop_s4_i8832fprop_1x1_roc_id_s4_128x64x64_64x64x64_1_nhwc_nc8hw8.cu", | |||
"cutlass_tensorop_s4_i8832fprop_roc_relu_s4_128x64x64_64x64x64_1_nhwc_nc8hw8.cu", | |||
"cutlass_tensorop_s4_i8832fprop_1x1_roc_relu_s4_128x64x64_64x64x64_1_nhwc_nc8hw8.cu", | |||
"cutlass_tensorop_s4_i8832fprop_roc_hswish_s4_128x64x64_64x64x64_1_nhwc_nc8hw8.cu", | |||
"cutlass_tensorop_s4_i8832fprop_1x1_roc_hswish_s4_128x64x64_64x64x64_1_nhwc_nc8hw8.cu", | |||
"cutlass_tensorop_s4_i8832fprop_id_s4_128x32x64_64x32x64_1_nhwc_nc16hw16.cu", | |||
"cutlass_tensorop_s4_i8832fprop_1x1_id_s4_128x32x64_64x32x64_1_nhwc_nc16hw16.cu", | |||
"cutlass_tensorop_s4_i8832fprop_relu_s4_128x32x64_64x32x64_1_nhwc_nc16hw16.cu", | |||
"cutlass_tensorop_s4_i8832fprop_1x1_relu_s4_128x32x64_64x32x64_1_nhwc_nc16hw16.cu", | |||
"cutlass_tensorop_s4_i8832fprop_hswish_s4_128x32x64_64x32x64_1_nhwc_nc16hw16.cu", | |||
"cutlass_tensorop_s4_i8832fprop_1x1_hswish_s4_128x32x64_64x32x64_1_nhwc_nc16hw16.cu", | |||
"cutlass_tensorop_s4_i8832fprop_roc_id_s4_128x32x64_64x32x64_1_nhwc_nc16hw16.cu", | |||
"cutlass_tensorop_s4_i8832fprop_1x1_roc_id_s4_128x32x64_64x32x64_1_nhwc_nc16hw16.cu", | |||
"cutlass_tensorop_s4_i8832fprop_roc_relu_s4_128x32x64_64x32x64_1_nhwc_nc16hw16.cu", | |||
"cutlass_tensorop_s4_i8832fprop_1x1_roc_relu_s4_128x32x64_64x32x64_1_nhwc_nc16hw16.cu", | |||
"cutlass_tensorop_s4_i8832fprop_roc_hswish_s4_128x32x64_64x32x64_1_nhwc_nc16hw16.cu", | |||
"cutlass_tensorop_s4_i8832fprop_1x1_roc_hswish_s4_128x32x64_64x32x64_1_nhwc_nc16hw16.cu", | |||
"cutlass_tensorop_s4_i8832fprop_id_s4_128x64x64_64x64x64_1_nhwc_nc16hw16.cu", | |||
"cutlass_tensorop_s4_i8832fprop_1x1_id_s4_128x64x64_64x64x64_1_nhwc_nc16hw16.cu", | |||
"cutlass_tensorop_s4_i8832fprop_relu_s4_128x64x64_64x64x64_1_nhwc_nc16hw16.cu", | |||
"cutlass_tensorop_s4_i8832fprop_1x1_relu_s4_128x64x64_64x64x64_1_nhwc_nc16hw16.cu", | |||
"cutlass_tensorop_s4_i8832fprop_hswish_s4_128x64x64_64x64x64_1_nhwc_nc16hw16.cu", | |||
"cutlass_tensorop_s4_i8832fprop_1x1_hswish_s4_128x64x64_64x64x64_1_nhwc_nc16hw16.cu", | |||
"cutlass_tensorop_s4_i8832fprop_roc_id_s4_128x64x64_64x64x64_1_nhwc_nc16hw16.cu", | |||
"cutlass_tensorop_s4_i8832fprop_1x1_roc_id_s4_128x64x64_64x64x64_1_nhwc_nc16hw16.cu", | |||
"cutlass_tensorop_s4_i8832fprop_roc_relu_s4_128x64x64_64x64x64_1_nhwc_nc16hw16.cu", | |||
"cutlass_tensorop_s4_i8832fprop_1x1_roc_relu_s4_128x64x64_64x64x64_1_nhwc_nc16hw16.cu", | |||
"cutlass_tensorop_s4_i8832fprop_roc_hswish_s4_128x64x64_64x64x64_1_nhwc_nc16hw16.cu", | |||
"cutlass_tensorop_s4_i8832fprop_1x1_roc_hswish_s4_128x64x64_64x64x64_1_nhwc_nc16hw16.cu", | |||
"cutlass_tensorop_s4_i8832fprop_id_s4_128x32x64_64x32x64_1_nhwc_nc32hw32.cu", | |||
"cutlass_tensorop_s4_i8832fprop_1x1_id_s4_128x32x64_64x32x64_1_nhwc_nc32hw32.cu", | |||
"cutlass_tensorop_s4_i8832fprop_relu_s4_128x32x64_64x32x64_1_nhwc_nc32hw32.cu", | |||
"cutlass_tensorop_s4_i8832fprop_1x1_relu_s4_128x32x64_64x32x64_1_nhwc_nc32hw32.cu", | |||
"cutlass_tensorop_s4_i8832fprop_hswish_s4_128x32x64_64x32x64_1_nhwc_nc32hw32.cu", | |||
"cutlass_tensorop_s4_i8832fprop_1x1_hswish_s4_128x32x64_64x32x64_1_nhwc_nc32hw32.cu", | |||
"cutlass_tensorop_s4_i8832fprop_roc_id_s4_128x32x64_64x32x64_1_nhwc_nc32hw32.cu", | |||
"cutlass_tensorop_s4_i8832fprop_1x1_roc_id_s4_128x32x64_64x32x64_1_nhwc_nc32hw32.cu", | |||
"cutlass_tensorop_s4_i8832fprop_roc_relu_s4_128x32x64_64x32x64_1_nhwc_nc32hw32.cu", | |||
"cutlass_tensorop_s4_i8832fprop_1x1_roc_relu_s4_128x32x64_64x32x64_1_nhwc_nc32hw32.cu", | |||
"cutlass_tensorop_s4_i8832fprop_roc_hswish_s4_128x32x64_64x32x64_1_nhwc_nc32hw32.cu", | |||
"cutlass_tensorop_s4_i8832fprop_1x1_roc_hswish_s4_128x32x64_64x32x64_1_nhwc_nc32hw32.cu", | |||
"cutlass_tensorop_s4_i8832fprop_id_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", | |||
"cutlass_tensorop_s4_i8832fprop_1x1_id_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", | |||
"cutlass_tensorop_s4_i8832fprop_relu_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", | |||
"cutlass_tensorop_s4_i8832fprop_1x1_relu_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", | |||
"cutlass_tensorop_s4_i8832fprop_hswish_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", | |||
"cutlass_tensorop_s4_i8832fprop_1x1_hswish_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", | |||
"cutlass_tensorop_s4_i8832fprop_roc_id_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", | |||
"cutlass_tensorop_s4_i8832fprop_1x1_roc_id_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", | |||
"cutlass_tensorop_s4_i8832fprop_roc_relu_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", | |||
"cutlass_tensorop_s4_i8832fprop_1x1_roc_relu_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", | |||
"cutlass_tensorop_s4_i8832fprop_roc_hswish_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", | |||
"cutlass_tensorop_s4_i8832fprop_1x1_roc_hswish_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", | |||
"cutlass_tensorop_u4_i8832fprop_id_u4_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", | |||
"cutlass_tensorop_u4_i8832fprop_1x1_id_u4_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", | |||
"cutlass_tensorop_u4_i8832fprop_relu_u4_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", | |||
"cutlass_tensorop_u4_i8832fprop_1x1_relu_u4_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", | |||
"cutlass_tensorop_u4_i8832fprop_roc_id_u4_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", | |||
"cutlass_tensorop_u4_i8832fprop_1x1_roc_id_u4_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", | |||
"cutlass_tensorop_u4_i8832fprop_roc_relu_u4_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", | |||
"cutlass_tensorop_u4_i8832fprop_1x1_roc_relu_u4_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", | |||
"cutlass_tensorop_u4_i8832fprop_id_u4_s4_128x64x64_64x64x64_1_nhwc_nc8hw8.cu", | |||
"cutlass_tensorop_u4_i8832fprop_1x1_id_u4_s4_128x64x64_64x64x64_1_nhwc_nc8hw8.cu", | |||
"cutlass_tensorop_u4_i8832fprop_relu_u4_s4_128x64x64_64x64x64_1_nhwc_nc8hw8.cu", | |||
"cutlass_tensorop_u4_i8832fprop_1x1_relu_u4_s4_128x64x64_64x64x64_1_nhwc_nc8hw8.cu", | |||
"cutlass_tensorop_u4_i8832fprop_roc_id_u4_s4_128x64x64_64x64x64_1_nhwc_nc8hw8.cu", | |||
"cutlass_tensorop_u4_i8832fprop_1x1_roc_id_u4_s4_128x64x64_64x64x64_1_nhwc_nc8hw8.cu", | |||
"cutlass_tensorop_u4_i8832fprop_roc_relu_u4_s4_128x64x64_64x64x64_1_nhwc_nc8hw8.cu", | |||
"cutlass_tensorop_u4_i8832fprop_1x1_roc_relu_u4_s4_128x64x64_64x64x64_1_nhwc_nc8hw8.cu", | |||
"cutlass_tensorop_u4_i8832fprop_id_u4_s4_128x32x64_64x32x64_1_nhwc_nc16hw16.cu", | |||
"cutlass_tensorop_u4_i8832fprop_1x1_id_u4_s4_128x32x64_64x32x64_1_nhwc_nc16hw16.cu", | |||
"cutlass_tensorop_u4_i8832fprop_relu_u4_s4_128x32x64_64x32x64_1_nhwc_nc16hw16.cu", | |||
"cutlass_tensorop_u4_i8832fprop_1x1_relu_u4_s4_128x32x64_64x32x64_1_nhwc_nc16hw16.cu", | |||
"cutlass_tensorop_u4_i8832fprop_roc_id_u4_s4_128x32x64_64x32x64_1_nhwc_nc16hw16.cu", | |||
"cutlass_tensorop_u4_i8832fprop_1x1_roc_id_u4_s4_128x32x64_64x32x64_1_nhwc_nc16hw16.cu", | |||
"cutlass_tensorop_u4_i8832fprop_roc_relu_u4_s4_128x32x64_64x32x64_1_nhwc_nc16hw16.cu", | |||
"cutlass_tensorop_u4_i8832fprop_1x1_roc_relu_u4_s4_128x32x64_64x32x64_1_nhwc_nc16hw16.cu", | |||
"cutlass_tensorop_u4_i8832fprop_id_u4_s4_128x64x64_64x64x64_1_nhwc_nc16hw16.cu", | |||
"cutlass_tensorop_u4_i8832fprop_1x1_id_u4_s4_128x64x64_64x64x64_1_nhwc_nc16hw16.cu", | |||
"cutlass_tensorop_u4_i8832fprop_relu_u4_s4_128x64x64_64x64x64_1_nhwc_nc16hw16.cu", | |||
"cutlass_tensorop_u4_i8832fprop_1x1_relu_u4_s4_128x64x64_64x64x64_1_nhwc_nc16hw16.cu", | |||
"cutlass_tensorop_u4_i8832fprop_roc_id_u4_s4_128x64x64_64x64x64_1_nhwc_nc16hw16.cu", | |||
"cutlass_tensorop_u4_i8832fprop_1x1_roc_id_u4_s4_128x64x64_64x64x64_1_nhwc_nc16hw16.cu", | |||
"cutlass_tensorop_u4_i8832fprop_roc_relu_u4_s4_128x64x64_64x64x64_1_nhwc_nc16hw16.cu", | |||
"cutlass_tensorop_u4_i8832fprop_1x1_roc_relu_u4_s4_128x64x64_64x64x64_1_nhwc_nc16hw16.cu", | |||
"cutlass_tensorop_u4_i8832fprop_id_u4_s4_128x32x64_64x32x64_1_nhwc_nc32hw32.cu", | |||
"cutlass_tensorop_u4_i8832fprop_1x1_id_u4_s4_128x32x64_64x32x64_1_nhwc_nc32hw32.cu", | |||
"cutlass_tensorop_u4_i8832fprop_relu_u4_s4_128x32x64_64x32x64_1_nhwc_nc32hw32.cu", | |||
"cutlass_tensorop_u4_i8832fprop_1x1_relu_u4_s4_128x32x64_64x32x64_1_nhwc_nc32hw32.cu", | |||
"cutlass_tensorop_u4_i8832fprop_roc_id_u4_s4_128x32x64_64x32x64_1_nhwc_nc32hw32.cu", | |||
"cutlass_tensorop_u4_i8832fprop_1x1_roc_id_u4_s4_128x32x64_64x32x64_1_nhwc_nc32hw32.cu", | |||
"cutlass_tensorop_u4_i8832fprop_roc_relu_u4_s4_128x32x64_64x32x64_1_nhwc_nc32hw32.cu", | |||
"cutlass_tensorop_u4_i8832fprop_1x1_roc_relu_u4_s4_128x32x64_64x32x64_1_nhwc_nc32hw32.cu", | |||
"cutlass_tensorop_u4_i8832fprop_id_u4_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", | |||
"cutlass_tensorop_u4_i8832fprop_1x1_id_u4_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", | |||
"cutlass_tensorop_u4_i8832fprop_relu_u4_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", | |||
"cutlass_tensorop_u4_i8832fprop_1x1_relu_u4_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", | |||
"cutlass_tensorop_u4_i8832fprop_roc_id_u4_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", | |||
"cutlass_tensorop_u4_i8832fprop_1x1_roc_id_u4_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", | |||
"cutlass_tensorop_u4_i8832fprop_roc_relu_u4_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", | |||
"cutlass_tensorop_u4_i8832fprop_1x1_roc_relu_u4_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", | |||
] |
@@ -217,56 +217,68 @@ void ConvBiasForwardImpl::AlgoPack::fill_imma_algos() { | |||
#if CUDA_VERSION >= 10020 | |||
{ | |||
using AlgoParam = AlgoInt8NCHW32IMMAImplicitGemm::AlgoParam; | |||
int8_nchw32_imma.emplace_back(AlgoParam{128, 256, 64, 64, 64, 64}); | |||
int8_nchw32_imma.emplace_back(AlgoParam{256, 128, 64, 64, 64, 64}); | |||
int8_nchw32_imma.emplace_back(AlgoParam{128, 128, 64, 64, 64, 64}); | |||
int8_nchw32_imma.emplace_back(AlgoParam{64, 128, 64, 32, 64, 64}); | |||
int8_nchw32_imma.emplace_back(AlgoParam{128, 64, 64, 64, 32, 64}); | |||
int8_nchw32_imma.emplace_back(AlgoParam{64, 64, 64, 32, 32, 64}); | |||
int8_nchw32_imma.emplace_back(AlgoParam{32, 64, 64, 32, 16, 64}); | |||
int8_nchw32_imma.emplace_back(AlgoParam{128, 256, 64, 64, 64, 64, 2}); | |||
int8_nchw32_imma.emplace_back(AlgoParam{256, 128, 64, 64, 64, 64, 2}); | |||
int8_nchw32_imma.emplace_back(AlgoParam{128, 128, 64, 64, 64, 64, 2}); | |||
int8_nchw32_imma.emplace_back(AlgoParam{128, 64, 64, 64, 32, 64, 2}); | |||
int8_nchw32_imma.emplace_back(AlgoParam{64, 128, 64, 32, 64, 64, 2}); | |||
int8_nchw32_imma.emplace_back(AlgoParam{128, 64, 32, 64, 32, 32, 1}); | |||
int8_nchw32_imma.emplace_back(AlgoParam{128, 32, 32, 64, 32, 32, 1}); | |||
int8_nchw32_imma.emplace_back(AlgoParam{64, 128, 32, 32, 64, 32, 1}); | |||
int8_nchw32_imma.emplace_back(AlgoParam{32, 128, 32, 32, 64, 32, 1}); | |||
} | |||
{ | |||
using AlgoParam = AlgoInt4Int4NCHW64IMMAImplicitGemm::AlgoParam; | |||
int4_int4_nchw64_imma.emplace_back( | |||
AlgoParam{128, 128, 128, 64, 64, 128}); | |||
AlgoParam{128, 128, 128, 64, 64, 128, 2}); | |||
int4_int4_nchw64_imma.emplace_back( | |||
AlgoParam{256, 128, 128, 64, 64, 128}); | |||
AlgoParam{128, 256, 128, 64, 64, 128, 2}); | |||
int4_int4_nchw64_imma.emplace_back( | |||
AlgoParam{128, 64, 128, 64, 64, 128, 2}); | |||
int4_int4_nchw64_imma.emplace_back( | |||
AlgoParam{128, 64, 64, 64, 64, 64, 1}); | |||
} | |||
{ | |||
using AlgoParam = AlgoUInt4Int4NCHW64IMMAImplicitGemm::AlgoParam; | |||
uint4_int4_nchw64_imma.emplace_back( | |||
AlgoParam{128, 128, 128, 64, 64, 128}); | |||
AlgoParam{128, 128, 128, 64, 64, 128, 2}); | |||
uint4_int4_nchw64_imma.emplace_back( | |||
AlgoParam{128, 256, 128, 64, 64, 128, 2}); | |||
uint4_int4_nchw64_imma.emplace_back( | |||
AlgoParam{256, 128, 128, 64, 64, 128}); | |||
AlgoParam{128, 64, 128, 64, 64, 128, 2}); | |||
uint4_int4_nchw64_imma.emplace_back( | |||
AlgoParam{128, 64, 64, 64, 64, 64, 1}); | |||
} | |||
{ | |||
using AlgoParam = AlgoInt4Int4NHWCIMMAImplicitGemm::AlgoParam; | |||
int4_int4_nhwc_imma.emplace_back( | |||
AlgoParam{128, 32, 64, 64, 32, 64, 32}); | |||
AlgoParam{128, 32, 64, 64, 32, 64, 1, 32}); | |||
int4_int4_nhwc_imma.emplace_back( | |||
AlgoParam{128, 32, 64, 64, 32, 64, 1, 16}); | |||
int4_int4_nhwc_imma.emplace_back( | |||
AlgoParam{128, 32, 64, 64, 32, 64, 1, 8}); | |||
int4_int4_nhwc_imma.emplace_back( | |||
AlgoParam{128, 32, 64, 64, 32, 64, 16}); | |||
int4_int4_nhwc_imma.emplace_back(AlgoParam{128, 32, 64, 64, 32, 64, 8}); | |||
AlgoParam{128, 64, 64, 64, 64, 64, 1, 32}); | |||
int4_int4_nhwc_imma.emplace_back( | |||
AlgoParam{128, 64, 64, 64, 64, 64, 32}); | |||
AlgoParam{128, 64, 64, 64, 64, 64, 1, 16}); | |||
int4_int4_nhwc_imma.emplace_back( | |||
AlgoParam{128, 64, 64, 64, 64, 64, 16}); | |||
int4_int4_nhwc_imma.emplace_back(AlgoParam{128, 64, 64, 64, 64, 64, 8}); | |||
AlgoParam{128, 64, 64, 64, 64, 64, 1, 8}); | |||
} | |||
{ | |||
using AlgoParam = AlgoUInt4Int4NHWCIMMAImplicitGemm::AlgoParam; | |||
uint4_int4_nhwc_imma.emplace_back( | |||
AlgoParam{128, 32, 64, 64, 32, 64, 32}); | |||
AlgoParam{128, 32, 64, 64, 32, 64, 1, 32}); | |||
uint4_int4_nhwc_imma.emplace_back( | |||
AlgoParam{128, 32, 64, 64, 32, 64, 16}); | |||
AlgoParam{128, 32, 64, 64, 32, 64, 1, 16}); | |||
uint4_int4_nhwc_imma.emplace_back( | |||
AlgoParam{128, 32, 64, 64, 32, 64, 8}); | |||
AlgoParam{128, 32, 64, 64, 32, 64, 1, 8}); | |||
uint4_int4_nhwc_imma.emplace_back( | |||
AlgoParam{128, 64, 64, 64, 64, 64, 32}); | |||
AlgoParam{128, 64, 64, 64, 64, 64, 1, 32}); | |||
uint4_int4_nhwc_imma.emplace_back( | |||
AlgoParam{128, 64, 64, 64, 64, 64, 16}); | |||
AlgoParam{128, 64, 64, 64, 64, 64, 1, 16}); | |||
uint4_int4_nhwc_imma.emplace_back( | |||
AlgoParam{128, 64, 64, 64, 64, 64, 8}); | |||
AlgoParam{128, 64, 64, 64, 64, 64, 1, 8}); | |||
} | |||
#endif | |||
} | |||
@@ -279,10 +291,8 @@ void ConvBiasForwardImpl::AlgoPack::fill_dp4a_algos() { | |||
int8_nchw4_dotprod.emplace_back(AlgoParam{64, 128, 32, 64, 32, 32, 2}); | |||
int8_nchw4_dotprod.emplace_back(AlgoParam{32, 128, 32, 32, 64, 32, 2}); | |||
int8_nchw4_dotprod.emplace_back(AlgoParam{128, 32, 32, 64, 32, 32, 2}); | |||
int8_nchw4_dotprod.emplace_back(AlgoParam{64, 64, 32, 64, 32, 32, 2}); | |||
int8_nchw4_dotprod.emplace_back(AlgoParam{32, 64, 32, 32, 64, 32, 2}); | |||
int8_nchw4_dotprod.emplace_back(AlgoParam{64, 32, 32, 64, 32, 32, 2}); | |||
int8_nchw4_dotprod.emplace_back(AlgoParam{32, 32, 32, 32, 32, 32, 2}); | |||
int8_nchw4_dotprod.emplace_back(AlgoParam{16, 128, 16, 16, 128, 16, 1}); | |||
int8_nchw4_dotprod.emplace_back(AlgoParam{16, 64, 8, 16, 64, 8, 2}); | |||
} | |||
@@ -723,6 +723,7 @@ public: | |||
int warp_m; | |||
int warp_n; | |||
int warp_k; | |||
int stage; | |||
}; | |||
AlgoInt8NCHW32IMMAImplicitGemm(AlgoParam algo_param) | |||
: m_algo_param{algo_param} { | |||
@@ -770,6 +771,7 @@ public: | |||
int warp_m; | |||
int warp_n; | |||
int warp_k; | |||
int stage; | |||
}; | |||
AlgoInt4NCHW64IMMAImplicitGemmBase(AlgoParam algo_param) | |||
@@ -897,6 +899,7 @@ public: | |||
int warp_m; | |||
int warp_n; | |||
int warp_k; | |||
int stage; | |||
int access_size; | |||
}; | |||
@@ -38,7 +38,7 @@ void do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32( | |||
const convolution::ConvParam& param, uint32_t nonlinear_mode, | |||
float alpha, float beta, float gamma, float scale, | |||
const GemmCoord& threadblock_shape, const GemmCoord& warp_shape, | |||
cudaStream_t stream); | |||
int stages, cudaStream_t stream); | |||
template <bool NeedLoadFromConstMem> | |||
void do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4( | |||
@@ -47,7 +47,7 @@ void do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4( | |||
const convolution::ConvParam& param, uint32_t nonlinear_mode, | |||
float alpha, float beta, float gamma, float scale, | |||
const GemmCoord& threadblock_shape, const GemmCoord& warp_shape, | |||
cudaStream_t stream); | |||
int stages, cudaStream_t stream); | |||
template <bool NeedLoadFromConstMem> | |||
void do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4( | |||
@@ -83,7 +83,7 @@ void do_conv_bias_int4_int4_implicit_gemm_imma_ncdiv64hw64( | |||
const convolution::ConvParam& param, uint32_t nonlinear_mode, | |||
float alpha, float beta, float gamma, float scale, | |||
const GemmCoord& threadblock_shape, const GemmCoord& warp_shape, | |||
cudaStream_t stream); | |||
int stages, cudaStream_t stream); | |||
template <bool NeedLoadFromConstMem> | |||
void do_conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64( | |||
@@ -92,7 +92,7 @@ void do_conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64( | |||
const convolution::ConvParam& param, uint32_t nonlinear_mode, | |||
float alpha, float beta, float gamma, float delta, float theta, | |||
float scale, uint8_t src_zero_point, const GemmCoord& threadblock_shape, | |||
const GemmCoord& warp_shape, cudaStream_t stream); | |||
const GemmCoord& warp_shape, int stages, cudaStream_t stream); | |||
template <bool signedness> | |||
void do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nhwc( | |||
@@ -110,7 +110,7 @@ void do_conv_bias_int4_int4_implicit_gemm_imma_nhwc( | |||
const convolution::ConvParam& param, uint32_t nonlinear_mode, | |||
float alpha, float beta, float gamma, float scale, | |||
const GemmCoord& threadblock_shape, const GemmCoord& warp_shape, | |||
const int32_t access_size, cudaStream_t stream); | |||
const int32_t access_size, int stages, cudaStream_t stream); | |||
template <bool NeedLoadFromConstMem> | |||
void do_conv_bias_uint4_int4_implicit_gemm_imma_nhwc( | |||
@@ -119,7 +119,7 @@ void do_conv_bias_uint4_int4_implicit_gemm_imma_nhwc( | |||
const convolution::ConvParam& param, uint32_t nonlinear_mode, | |||
float alpha, float beta, float gamma, float delta, float theta, | |||
float scale, uint8_t src_zero_point, const GemmCoord& threadblock_shape, | |||
const GemmCoord& warp_shape, const int32_t access_size, | |||
const GemmCoord& warp_shape, const int32_t access_size, int stages, | |||
cudaStream_t stream); | |||
} // namespace cutlass_wrapper | |||
@@ -0,0 +1,595 @@ | |||
/** | |||
* \file dnn/src/cuda/conv_bias/cutlass_convolution_wrapper.cu | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
// ignore warning of cutlass | |||
#pragma GCC diagnostic push | |||
#pragma GCC diagnostic ignored "-Wunused-parameter" | |||
#pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||
#if !MEGDNN_TEGRA_X1 | |||
#include "cutlass/convolution/device/convolution.h" | |||
#endif | |||
#include "src/common/opr_param_defs_enumv.cuh" | |||
#include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh" | |||
#pragma GCC diagnostic pop | |||
using namespace megdnn; | |||
using namespace cuda; | |||
using namespace cutlass_wrapper; | |||
/* ====== cutlass kernel wrapper for int4 x int4 nchw64 layout ====== */ | |||
#if MEGDNN_TEGRA_X1 | |||
template <bool NeedLoadFromConstMem> | |||
void megdnn::cuda::cutlass_wrapper:: | |||
do_conv_bias_int4_int4_implicit_gemm_imma_ncdiv64hw64( | |||
const int8_t* /* d_src */, const int8_t* /* d_filter */, | |||
const int32_t* /* d_bias */, const int8_t* /* d_z */, | |||
int8_t* /* d_dst */, int* /* workspace */, | |||
const convolution::ConvParam& /* param */, | |||
uint32_t /* nonlinear_mode */, float /* alpha */, | |||
float /* beta */, float /* gamma */, float /* scale */, | |||
const GemmCoord& /* threadblock_shape */, | |||
const GemmCoord& /* warp_shape */, int /* stages */, | |||
cudaStream_t /* stream */) {} | |||
#else | |||
template <bool NeedLoadFromConstMem> | |||
void megdnn::cuda::cutlass_wrapper:: | |||
do_conv_bias_int4_int4_implicit_gemm_imma_ncdiv64hw64( | |||
const int8_t* d_src, const int8_t* d_filter, | |||
const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, | |||
int* workspace, const convolution::ConvParam& param, | |||
uint32_t nonlinear_mode, float alpha, float beta, float gamma, | |||
float scale, const GemmCoord& threadblock_shape, | |||
const GemmCoord& warp_shape, int stages, cudaStream_t stream) { | |||
#define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \ | |||
threadblock_k_, warp_m_, warp_n_, \ | |||
warp_k_, stage_) \ | |||
if (threadblock_shape.m() == threadblock_m_ && \ | |||
threadblock_shape.n() == threadblock_n_ && \ | |||
threadblock_shape.k() == threadblock_k_ && \ | |||
warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ | |||
warp_shape.k() == warp_k_ && stages == stage_) { \ | |||
using ThreadBlockShape = \ | |||
cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \ | |||
threadblock_k_>; \ | |||
using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \ | |||
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; \ | |||
using Convolution = cutlass::conv::device::Convolution< \ | |||
cutlass::int4b_t, cutlass::layout::TensorNCxHWx<64>, \ | |||
cutlass::int4b_t, cutlass::layout::TensorCxRSKx<64>, \ | |||
ElementOutput, cutlass::layout::TensorNCxHWx<64>, int32_t, \ | |||
cutlass::layout::TensorNCxHWx<64>, int32_t, \ | |||
cutlass::conv::ConvType::kConvolution, \ | |||
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \ | |||
ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ | |||
cutlass::conv::threadblock:: \ | |||
ConvolutionFpropTransThreadblockSwizzle, \ | |||
stage_, 32, 32, NeedLoadFromConstMem, \ | |||
cutlass::arch::OpMultiplyAddSaturate, \ | |||
cutlass::conv::ImplicitGemmMode::GEMM_TN, true>; \ | |||
typename Convolution::ConvolutionParameter conv_param( \ | |||
param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ | |||
param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ | |||
param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \ | |||
return cutlass_convolution_wrapper<Convolution>( \ | |||
reinterpret_cast<const cutlass::int4b_t*>(d_src), \ | |||
reinterpret_cast<const cutlass::int4b_t*>(d_filter), d_bias, \ | |||
reinterpret_cast<const cutlass::int4b_t*>(d_z), \ | |||
reinterpret_cast<cutlass::int4b_t*>(d_dst), workspace, \ | |||
conv_param, epilogue, stream); \ | |||
} | |||
#define DISPATCH_KERNEL \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 128, 64, 64, 128, 2); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 256, 128, 64, 64, 128, 2); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 128, 64, 64, 128, 2); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1); \ | |||
megdnn_assert(false, \ | |||
"unsupported threadblock shape (%dx%dx%d) and warp shape " \ | |||
"(%dx%dx%d)", \ | |||
threadblock_shape.m(), threadblock_shape.n(), \ | |||
threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ | |||
warp_shape.k()); | |||
using ElementOutput = cutlass::int4b_t; | |||
using ElementAccumulator = int32_t; | |||
using ElementBias = int32_t; | |||
using ElementCompute = float; | |||
using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode; | |||
switch (nonlinear_mode) { | |||
case NonlineMode::IDENTITY: { | |||
using EpilogueOp = | |||
cutlass::epilogue::thread::BiasAddLinearCombinationClamp< | |||
ElementOutput, 16, ElementAccumulator, ElementBias, | |||
ElementCompute>; | |||
typename EpilogueOp::Params epilogue{alpha, beta, gamma}; | |||
DISPATCH_KERNEL; | |||
} | |||
case NonlineMode::RELU: { | |||
using EpilogueOp = cutlass::epilogue::thread:: | |||
BiasAddLinearCombinationReluClamp< | |||
ElementOutput, 16, ElementAccumulator, ElementBias, | |||
ElementCompute>; | |||
typename EpilogueOp::Params epilogue{alpha, beta, gamma, 0}; | |||
DISPATCH_KERNEL; | |||
} | |||
case NonlineMode::H_SWISH: { | |||
using EpilogueOp = cutlass::epilogue::thread:: | |||
BiasAddLinearCombinationHSwishClamp< | |||
ElementOutput, 16, ElementAccumulator, ElementBias, | |||
ElementCompute>; | |||
typename EpilogueOp::Params epilogue{alpha, beta, gamma, scale}; | |||
DISPATCH_KERNEL; | |||
} | |||
default: | |||
megdnn_assert(false, | |||
"unsupported nonlinear mode for conv bias operator"); | |||
} | |||
#undef DISPATCH_KERNEL_WITH_TILE_SHAPE | |||
#undef DISPATCH_KERNEL | |||
} | |||
#endif | |||
#define INST(need_load_from_const_mem) \ | |||
template void megdnn::cuda::cutlass_wrapper:: \ | |||
do_conv_bias_int4_int4_implicit_gemm_imma_ncdiv64hw64< \ | |||
need_load_from_const_mem>( \ | |||
const int8_t* d_src, const int8_t* d_filter, \ | |||
const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, \ | |||
int* workspace, const convolution::ConvParam& param, \ | |||
uint32_t nonlinear_mode, float alpha, float beta, \ | |||
float gamma, float scale, \ | |||
const GemmCoord& threadblock_shape, \ | |||
const GemmCoord& warp_shape, int stages, \ | |||
cudaStream_t stream); | |||
INST(true); | |||
#undef INST | |||
/* ====== cutlass kernel wrapper for uint4 x int4 nchw64 layout ====== */ | |||
#if MEGDNN_TEGRA_X1 | |||
template <bool NeedLoadFromConstMem> | |||
void megdnn::cuda::cutlass_wrapper:: | |||
do_conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64( | |||
const uint8_t* /* d_src */, const int8_t* /* d_filter */, | |||
const int32_t* /* d_bias */, const uint8_t* /* d_z */, | |||
uint8_t* /* d_dst */, int* /* workspace */, | |||
const convolution::ConvParam& /* param */, | |||
uint32_t /* nonlinear_mode */, float /* alpha */, | |||
float /* beta */, float /* gamma */, float /* delta */, | |||
float /* theta */, float /* scale */, | |||
uint8_t /* src_zero_point */, | |||
const GemmCoord& /* threadblock_shape */, | |||
const GemmCoord& /* warp_shape */, int /* stages */, | |||
cudaStream_t /* stream */) {} | |||
#else | |||
template <bool NeedLoadFromConstMem> | |||
void megdnn::cuda::cutlass_wrapper:: | |||
do_conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64( | |||
const uint8_t* d_src, const int8_t* d_filter, | |||
const int32_t* d_bias, const uint8_t* d_z, uint8_t* d_dst, | |||
int* workspace, const convolution::ConvParam& param, | |||
uint32_t nonlinear_mode, float alpha, float beta, float gamma, | |||
float delta, float theta, float /* scale */, | |||
uint8_t src_zero_point, const GemmCoord& threadblock_shape, | |||
const GemmCoord& warp_shape, int stages, cudaStream_t stream) { | |||
#define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \ | |||
threadblock_k_, warp_m_, warp_n_, \ | |||
warp_k_, stage_) \ | |||
if (threadblock_shape.m() == threadblock_m_ && \ | |||
threadblock_shape.n() == threadblock_n_ && \ | |||
threadblock_shape.k() == threadblock_k_ && \ | |||
warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ | |||
warp_shape.k() == warp_k_ && stages == stage_) { \ | |||
using ThreadBlockShape = \ | |||
cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \ | |||
threadblock_k_>; \ | |||
using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \ | |||
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; \ | |||
using Convolution = cutlass::conv::device::Convolution< \ | |||
cutlass::uint4b_t, cutlass::layout::TensorNCxHWx<64>, \ | |||
cutlass::int4b_t, cutlass::layout::TensorCxRSKx<64>, \ | |||
ElementOutput, cutlass::layout::TensorNCxHWx<64>, int32_t, \ | |||
cutlass::layout::TensorNCxHWx<64>, int32_t, \ | |||
cutlass::conv::ConvType::kConvolution, \ | |||
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \ | |||
ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ | |||
cutlass::conv::threadblock:: \ | |||
ConvolutionFpropTransThreadblockSwizzle, \ | |||
stage_, 32, 32, NeedLoadFromConstMem, \ | |||
cutlass::arch::OpMultiplyAddSaturate, \ | |||
cutlass::conv::ImplicitGemmMode::GEMM_TN, true>; \ | |||
typename Convolution::ConvolutionParameter conv_param( \ | |||
param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ | |||
param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ | |||
param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \ | |||
return cutlass_convolution_wrapper<Convolution>( \ | |||
reinterpret_cast<const cutlass::uint4b_t*>(d_src), \ | |||
reinterpret_cast<const cutlass::int4b_t*>(d_filter), d_bias, \ | |||
reinterpret_cast<const cutlass::uint4b_t*>(d_z), \ | |||
reinterpret_cast<cutlass::uint4b_t*>(d_dst), workspace, \ | |||
conv_param, epilogue, stream, {src_zero_point}); \ | |||
} | |||
#define DISPATCH_KERNEL \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 128, 64, 64, 128, 2); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 256, 128, 64, 64, 128, 2); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 128, 64, 64, 128, 2); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1); \ | |||
megdnn_assert(false, \ | |||
"unsupported threadblock shape (%dx%dx%d) and warp shape " \ | |||
"(%dx%dx%d)", \ | |||
threadblock_shape.m(), threadblock_shape.n(), \ | |||
threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ | |||
warp_shape.k()); | |||
using ElementOutput = cutlass::uint4b_t; | |||
using ElementAccumulator = int32_t; | |||
using ElementBias = int32_t; | |||
using ElementCompute = float; | |||
using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode; | |||
switch (nonlinear_mode) { | |||
case NonlineMode::IDENTITY: { | |||
using EpilogueOp = | |||
cutlass::epilogue::thread::BiasAddLinearCombinationClamp< | |||
ElementOutput, 16, ElementAccumulator, ElementBias, | |||
ElementCompute>; | |||
typename EpilogueOp::Params epilogue{alpha, beta, gamma, | |||
delta + theta}; | |||
DISPATCH_KERNEL; | |||
} | |||
case NonlineMode::RELU: { | |||
using EpilogueOp = cutlass::epilogue::thread:: | |||
BiasAddLinearCombinationReluClamp< | |||
ElementOutput, 16, ElementAccumulator, ElementBias, | |||
ElementCompute>; | |||
typename EpilogueOp::Params epilogue{alpha, beta, gamma, | |||
0, delta, theta}; | |||
DISPATCH_KERNEL; | |||
} | |||
default: | |||
megdnn_assert(false, | |||
"unsupported nonlinear mode for conv bias operator"); | |||
} | |||
#undef DISPATCH_KERNEL_WITH_TILE_SHAPE | |||
#undef DISPATCH_KERNEL | |||
} | |||
#endif | |||
#define INST(need_load_from_const_mem) \ | |||
template void megdnn::cuda::cutlass_wrapper:: \ | |||
do_conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64< \ | |||
need_load_from_const_mem>( \ | |||
const uint8_t* d_src, const int8_t* d_filter, \ | |||
const int32_t* d_bias, const uint8_t* d_z, uint8_t* d_dst, \ | |||
int* workspace, const convolution::ConvParam& param, \ | |||
uint32_t nonlinear_mode, float alpha, float beta, \ | |||
float gamma, float delta, float theta, float scale, \ | |||
uint8_t src_zero_point, \ | |||
const GemmCoord& threadblock_shape, \ | |||
const GemmCoord& warp_shape, int stages, \ | |||
cudaStream_t stream); | |||
INST(true); | |||
#undef INST | |||
/* ====== cutlass kernel wrapper for int4 x int4 nhwc layout ====== */ | |||
#if MEGDNN_TEGRA_X1 | |||
template <bool NeedLoadFromConstMem> | |||
void megdnn::cuda::cutlass_wrapper:: | |||
do_conv_bias_int4_int4_implicit_gemm_imma_nhwc( | |||
const int8_t* /* d_src */, const int8_t* /* d_filter */, | |||
const int32_t* /* d_bias */, const int8_t* /* d_z */, | |||
int8_t* /* d_dst */, int* /* workspace */, | |||
const convolution::ConvParam& /* param */, | |||
uint32_t /* nonlinear_mode */, float /* alpha */, | |||
float /* beta */, float /* gamma */, float /* scale */, | |||
const GemmCoord& /* threadblock_shape */, | |||
const GemmCoord& /* warp_shape */, | |||
const int32_t /* access_size */, int /* stages */, | |||
cudaStream_t /* stream */) {} | |||
#else | |||
template <bool NeedLoadFromConstMem> | |||
void megdnn::cuda::cutlass_wrapper:: | |||
do_conv_bias_int4_int4_implicit_gemm_imma_nhwc( | |||
const int8_t* d_src, const int8_t* d_filter, | |||
const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, | |||
int* workspace, const convolution::ConvParam& param, | |||
uint32_t nonlinear_mode, float alpha, float beta, float gamma, | |||
float scale, const GemmCoord& threadblock_shape, | |||
const GemmCoord& warp_shape, const int32_t access_size, | |||
int stages, cudaStream_t stream) { | |||
bool without_shared_load = | |||
((param.co % threadblock_shape.n() == 0) && | |||
(threadblock_shape.n() == 32 || threadblock_shape.n() == 64)); | |||
int out_elements_per_access = | |||
without_shared_load ? threadblock_shape.n() / 4 : 8; | |||
#define RUN_CUTLASS_WRAPPER(stage_, access_size_, without_shared_load_) \ | |||
using Convolution = cutlass::conv::device::Convolution< \ | |||
cutlass::int4b_t, cutlass::layout::TensorNHWC, cutlass::int4b_t, \ | |||
cutlass::layout::TensorNCxHWx<access_size_>, ElementOutput, \ | |||
cutlass::layout::TensorNHWC, int32_t, cutlass::layout::TensorNHWC, \ | |||
int32_t, cutlass::conv::ConvType::kConvolution, \ | |||
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \ | |||
ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ | |||
cutlass::conv::threadblock:: \ | |||
ConvolutionFpropTransThreadblockSwizzle, \ | |||
stage_, access_size_, access_size_, NeedLoadFromConstMem, \ | |||
cutlass::arch::OpMultiplyAddSaturate, \ | |||
cutlass::conv::ImplicitGemmMode::GEMM_TN, without_shared_load_>; \ | |||
typename Convolution::ConvolutionParameter conv_param( \ | |||
param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ | |||
param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ | |||
param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \ | |||
return cutlass_convolution_wrapper<Convolution>( \ | |||
reinterpret_cast<const cutlass::int4b_t*>(d_src), \ | |||
reinterpret_cast<const cutlass::int4b_t*>(d_filter), d_bias, \ | |||
reinterpret_cast<const cutlass::int4b_t*>(d_z), \ | |||
reinterpret_cast<cutlass::int4b_t*>(d_dst), workspace, conv_param, \ | |||
epilogue, stream); | |||
#define DISPATCH_KERNEL_WITH_TILE_SHAPE( \ | |||
threadblock_m_, threadblock_n_, threadblock_k_, warp_m_, warp_n_, \ | |||
warp_k_, stage_, access_size_, out_elements_per_access_, \ | |||
without_shared_load_) \ | |||
if (threadblock_shape.m() == threadblock_m_ && \ | |||
threadblock_shape.n() == threadblock_n_ && \ | |||
threadblock_shape.k() == threadblock_k_ && \ | |||
warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ | |||
warp_shape.k() == warp_k_ && stages == stage_ && \ | |||
access_size == access_size_ && \ | |||
out_elements_per_access == out_elements_per_access_ && \ | |||
without_shared_load == without_shared_load_) { \ | |||
using ThreadBlockShape = \ | |||
cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \ | |||
threadblock_k_>; \ | |||
using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \ | |||
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; \ | |||
using ElementOutput = cutlass::int4b_t; \ | |||
using ElementAccumulator = int32_t; \ | |||
using ElementBias = int32_t; \ | |||
using ElementCompute = float; \ | |||
using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode; \ | |||
switch (nonlinear_mode) { \ | |||
case NonlineMode::IDENTITY: { \ | |||
using EpilogueOp = cutlass::epilogue::thread:: \ | |||
BiasAddLinearCombinationClamp< \ | |||
ElementOutput, out_elements_per_access_, \ | |||
ElementAccumulator, ElementBias, \ | |||
ElementCompute>; \ | |||
typename EpilogueOp::Params epilogue{alpha, beta, gamma}; \ | |||
RUN_CUTLASS_WRAPPER(stage_, access_size_, \ | |||
without_shared_load_); \ | |||
} \ | |||
case NonlineMode::RELU: { \ | |||
using EpilogueOp = cutlass::epilogue::thread:: \ | |||
BiasAddLinearCombinationReluClamp< \ | |||
ElementOutput, out_elements_per_access_, \ | |||
ElementAccumulator, ElementBias, \ | |||
ElementCompute>; \ | |||
typename EpilogueOp::Params epilogue{alpha, beta, gamma, 0}; \ | |||
RUN_CUTLASS_WRAPPER(stage_, access_size_, \ | |||
without_shared_load_); \ | |||
} \ | |||
case NonlineMode::H_SWISH: { \ | |||
using EpilogueOp = cutlass::epilogue::thread:: \ | |||
BiasAddLinearCombinationHSwishClamp< \ | |||
ElementOutput, out_elements_per_access_, \ | |||
ElementAccumulator, ElementBias, \ | |||
ElementCompute>; \ | |||
typename EpilogueOp::Params epilogue{alpha, beta, gamma, \ | |||
scale}; \ | |||
RUN_CUTLASS_WRAPPER(stage_, access_size_, \ | |||
without_shared_load_); \ | |||
} \ | |||
default: \ | |||
megdnn_assert( \ | |||
false, \ | |||
"unsupported nonlinear mode for conv bias operator"); \ | |||
} \ | |||
} | |||
#define DISPATCH_KERNEL \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 32, 8, false); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 16, 8, false); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 8, 8, false); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 32, 8, false); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 16, 8, false); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 8, 8, false); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 32, 8, true); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 16, 8, true); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 8, 8, true); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 32, 16, true); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 16, 16, true); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 8, 16, true); \ | |||
megdnn_assert(false, \ | |||
"unsupported threadblock shape (%dx%dx%d) and warp shape " \ | |||
"(%dx%dx%d) and access_size (%d)", \ | |||
threadblock_shape.m(), threadblock_shape.n(), \ | |||
threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ | |||
warp_shape.k(), access_size); | |||
DISPATCH_KERNEL; | |||
#undef RUN_CUTLASS_WRAPPER | |||
#undef DISPATCH_KERNEL_WITH_TILE_SHAPE | |||
#undef DISPATCH_KERNEL | |||
} | |||
#endif | |||
#define INST(need_load_from_const_mem) \ | |||
template void megdnn::cuda::cutlass_wrapper:: \ | |||
do_conv_bias_int4_int4_implicit_gemm_imma_nhwc< \ | |||
need_load_from_const_mem>( \ | |||
const int8_t* d_src, const int8_t* d_filter, \ | |||
const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, \ | |||
int* workspace, const convolution::ConvParam& param, \ | |||
uint32_t nonlinear_mode, float alpha, float beta, \ | |||
float gamma, float scale, \ | |||
const GemmCoord& threadblock_shape, \ | |||
const GemmCoord& warp_shape, const int32_t access_size, \ | |||
int stages, cudaStream_t stream); | |||
INST(true); | |||
INST(false); | |||
#undef INST | |||
/* ====== cutlass kernel wrapper for uint4 x int4 nhwc layout ====== */ | |||
#if MEGDNN_TEGRA_X1 | |||
template <bool NeedLoadFromConstMem> | |||
void megdnn::cuda::cutlass_wrapper:: | |||
do_conv_bias_uint4_int4_implicit_gemm_imma_nhwc( | |||
const uint8_t* /* d_src */, const int8_t* /* d_filter */, | |||
const int32_t* /* d_bias */, const uint8_t* /* d_z */, | |||
uint8_t* /* d_dst */, int* /* workspace */, | |||
const convolution::ConvParam& /* param */, | |||
uint32_t /* nonlinear_mode */, float /* alpha */, | |||
float /* beta */, float /* gamma */, float /* delta */, | |||
float /* theta */, float /* scale */, | |||
uint8_t /* src_zero_point */, | |||
const GemmCoord& /* threadblock_shape */, | |||
const GemmCoord& /* warp_shape */, | |||
const int32_t /* access_size */, int /* stages */, | |||
cudaStream_t /* stream */) {} | |||
#else | |||
template <bool NeedLoadFromConstMem> | |||
void megdnn::cuda::cutlass_wrapper:: | |||
do_conv_bias_uint4_int4_implicit_gemm_imma_nhwc( | |||
const uint8_t* d_src, const int8_t* d_filter, | |||
const int32_t* d_bias, const uint8_t* d_z, uint8_t* d_dst, | |||
int* workspace, const convolution::ConvParam& param, | |||
uint32_t nonlinear_mode, float alpha, float beta, float gamma, | |||
float delta, float theta, float /* scale */, | |||
uint8_t src_zero_point, const GemmCoord& threadblock_shape, | |||
const GemmCoord& warp_shape, const int32_t access_size, | |||
int stages, cudaStream_t stream) { | |||
bool without_shared_load = | |||
((param.co % threadblock_shape.n() == 0) && | |||
(threadblock_shape.n() == 32 || threadblock_shape.n() == 64)); | |||
int out_elements_per_access = | |||
without_shared_load ? threadblock_shape.n() / 4 : 8; | |||
#define RUN_CUTLASS_WRAPPER(stage_, access_size_, without_shared_load_) \ | |||
using Convolution = cutlass::conv::device::Convolution< \ | |||
cutlass::uint4b_t, cutlass::layout::TensorNHWC, cutlass::int4b_t, \ | |||
cutlass::layout::TensorNCxHWx<access_size_>, ElementOutput, \ | |||
cutlass::layout::TensorNHWC, int32_t, cutlass::layout::TensorNHWC, \ | |||
int32_t, cutlass::conv::ConvType::kConvolution, \ | |||
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \ | |||
ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ | |||
cutlass::conv::threadblock:: \ | |||
ConvolutionFpropTransThreadblockSwizzle, \ | |||
stage_, access_size_, access_size_, NeedLoadFromConstMem, \ | |||
cutlass::arch::OpMultiplyAddSaturate, \ | |||
cutlass::conv::ImplicitGemmMode::GEMM_TN, without_shared_load_>; \ | |||
typename Convolution::ConvolutionParameter conv_param( \ | |||
param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ | |||
param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ | |||
param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \ | |||
return cutlass_convolution_wrapper<Convolution>( \ | |||
reinterpret_cast<const cutlass::uint4b_t*>(d_src), \ | |||
reinterpret_cast<const cutlass::int4b_t*>(d_filter), d_bias, \ | |||
reinterpret_cast<const cutlass::uint4b_t*>(d_z), \ | |||
reinterpret_cast<cutlass::uint4b_t*>(d_dst), workspace, \ | |||
conv_param, epilogue, stream, {src_zero_point}); | |||
#define DISPATCH_KERNEL_WITH_TILE_SHAPE( \ | |||
threadblock_m_, threadblock_n_, threadblock_k_, warp_m_, warp_n_, \ | |||
warp_k_, stage_, access_size_, out_elements_per_access_, \ | |||
without_shared_load_) \ | |||
if (threadblock_shape.m() == threadblock_m_ && \ | |||
threadblock_shape.n() == threadblock_n_ && \ | |||
threadblock_shape.k() == threadblock_k_ && \ | |||
warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ | |||
warp_shape.k() == warp_k_ && stages == stage_ && \ | |||
access_size == access_size_ && \ | |||
out_elements_per_access == out_elements_per_access_ && \ | |||
without_shared_load == without_shared_load_) { \ | |||
using ThreadBlockShape = \ | |||
cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \ | |||
threadblock_k_>; \ | |||
using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \ | |||
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; \ | |||
using ElementOutput = cutlass::uint4b_t; \ | |||
using ElementAccumulator = int32_t; \ | |||
using ElementBias = int32_t; \ | |||
using ElementCompute = float; \ | |||
using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode; \ | |||
switch (nonlinear_mode) { \ | |||
case NonlineMode::IDENTITY: { \ | |||
using EpilogueOp = cutlass::epilogue::thread:: \ | |||
BiasAddLinearCombinationClamp< \ | |||
ElementOutput, out_elements_per_access_, \ | |||
ElementAccumulator, ElementBias, \ | |||
ElementCompute>; \ | |||
typename EpilogueOp::Params epilogue{alpha, beta, gamma, \ | |||
delta + theta}; \ | |||
RUN_CUTLASS_WRAPPER(stage_, access_size_, \ | |||
without_shared_load_); \ | |||
} \ | |||
case NonlineMode::RELU: { \ | |||
using EpilogueOp = cutlass::epilogue::thread:: \ | |||
BiasAddLinearCombinationReluClamp< \ | |||
ElementOutput, out_elements_per_access_, \ | |||
ElementAccumulator, ElementBias, \ | |||
ElementCompute>; \ | |||
typename EpilogueOp::Params epilogue{alpha, beta, gamma, \ | |||
0, delta, theta}; \ | |||
RUN_CUTLASS_WRAPPER(stage_, access_size_, \ | |||
without_shared_load_); \ | |||
} \ | |||
default: \ | |||
megdnn_assert( \ | |||
false, \ | |||
"unsupported nonlinear mode for conv bias operator"); \ | |||
} \ | |||
} | |||
#define DISPATCH_KERNEL \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 32, 8, false); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 16, 8, false); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 8, 8, false); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 32, 8, false); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 16, 8, false); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 8, 8, false); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 32, 8, true); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 16, 8, true); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 8, 8, true); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 32, 16, true); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 16, 16, true); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 8, 16, true); \ | |||
megdnn_assert(false, \ | |||
"unsupported threadblock shape (%dx%dx%d) and warp shape " \ | |||
"(%dx%dx%d) and access_size (%d)", \ | |||
threadblock_shape.m(), threadblock_shape.n(), \ | |||
threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ | |||
warp_shape.k(), access_size); | |||
DISPATCH_KERNEL; | |||
#undef RUN_CUTLASS_WRAPPER | |||
#undef DISPATCH_KERNEL_WITH_TILE_SHAPE | |||
#undef DISPATCH_KERNEL | |||
} | |||
#endif | |||
#define INST(need_load_from_const_mem) \ | |||
template void megdnn::cuda::cutlass_wrapper:: \ | |||
do_conv_bias_uint4_int4_implicit_gemm_imma_nhwc< \ | |||
need_load_from_const_mem>( \ | |||
const uint8_t* d_src, const int8_t* d_filter, \ | |||
const int32_t* d_bias, const uint8_t* d_z, uint8_t* d_dst, \ | |||
int* workspace, const convolution::ConvParam& param, \ | |||
uint32_t nonlinear_mode, float alpha, float beta, \ | |||
float gamma, float delta, float theta, float scale, \ | |||
uint8_t src_zero_point, \ | |||
const GemmCoord& threadblock_shape, \ | |||
const GemmCoord& warp_shape, const int32_t access_size, \ | |||
int stages, cudaStream_t stream); | |||
INST(true); | |||
INST(false); | |||
#undef INST | |||
// vim: syntax=cuda.doxygen |
@@ -38,7 +38,8 @@ void megdnn::cuda::cutlass_wrapper:: | |||
uint32_t /* nonlinear_mode */, float /* alpha */, | |||
float /* beta */, float /* gamma */, float /* scale */, | |||
const GemmCoord& /* threadblock_shape */, | |||
const GemmCoord& /* warp_shape */, cudaStream_t /* stream */) {} | |||
const GemmCoord& /* warp_shape */, int /* stages */, | |||
cudaStream_t /* stream */) {} | |||
#else | |||
template <bool NeedLoadFromConstMem> | |||
void megdnn::cuda::cutlass_wrapper:: | |||
@@ -48,15 +49,15 @@ void megdnn::cuda::cutlass_wrapper:: | |||
int* workspace, const convolution::ConvParam& param, | |||
uint32_t nonlinear_mode, float alpha, float beta, float gamma, | |||
float scale, const GemmCoord& threadblock_shape, | |||
const GemmCoord& warp_shape, cudaStream_t stream) { | |||
const GemmCoord& warp_shape, int stages, cudaStream_t stream) { | |||
#define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \ | |||
threadblock_k_, warp_m_, warp_n_, \ | |||
warp_k_) \ | |||
warp_k_, stage_) \ | |||
if (threadblock_shape.m() == threadblock_m_ && \ | |||
threadblock_shape.n() == threadblock_n_ && \ | |||
threadblock_shape.k() == threadblock_k_ && \ | |||
warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ | |||
warp_shape.k() == warp_k_) { \ | |||
warp_shape.k() == warp_k_ && stages == stage_) { \ | |||
using ThreadBlockShape = \ | |||
cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \ | |||
threadblock_k_>; \ | |||
@@ -71,8 +72,10 @@ void megdnn::cuda::cutlass_wrapper:: | |||
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \ | |||
ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ | |||
cutlass::conv::threadblock:: \ | |||
ConvolutionFpropNCxHWxThreadblockSwizzle, \ | |||
2, 16, 16, NeedLoadFromConstMem>; \ | |||
ConvolutionFpropTransThreadblockSwizzle, \ | |||
stage_, 16, 16, NeedLoadFromConstMem, \ | |||
cutlass::arch::OpMultiplyAddSaturate, \ | |||
cutlass::conv::ImplicitGemmMode::GEMM_TN, true>; \ | |||
typename Convolution::ConvolutionParameter conv_param( \ | |||
param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ | |||
param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ | |||
@@ -82,13 +85,15 @@ void megdnn::cuda::cutlass_wrapper:: | |||
epilogue, stream); \ | |||
} | |||
#define DISPATCH_KERNEL \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(256, 128, 64, 64, 64, 64); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 256, 64, 64, 64, 64); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 64, 64, 64, 64); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 64, 32, 64, 64); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 32, 64); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 64, 64, 32, 32, 64); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 64, 32, 16, 64); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 256, 64, 64, 64, 64, 2); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(256, 128, 64, 64, 64, 64, 2); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 64, 64, 64, 64, 2); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 32, 64, 2); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 64, 32, 64, 64, 2); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 32, 64, 32, 32, 1); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 32, 64, 32, 32, 1); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 32, 64, 32, 1); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 1); \ | |||
megdnn_assert(false, \ | |||
"unsupported threadblock shape (%dx%dx%d) and warp shape " \ | |||
"(%dx%dx%d)", \ | |||
@@ -144,7 +149,8 @@ void megdnn::cuda::cutlass_wrapper:: | |||
uint32_t nonlinear_mode, float alpha, float beta, \ | |||
float gamma, float scale, \ | |||
const GemmCoord& threadblock_shape, \ | |||
const GemmCoord& warp_shape, cudaStream_t stream); | |||
const GemmCoord& warp_shape, int stages, \ | |||
cudaStream_t stream); | |||
INST(true); | |||
INST(false); | |||
#undef INST | |||
@@ -162,7 +168,8 @@ void megdnn::cuda::cutlass_wrapper:: | |||
uint32_t /* nonlinear_mode */, float /* alpha */, | |||
float /* beta */, float /* gamma */, float /* scale */, | |||
const GemmCoord& /* threadblock_shape */, | |||
const GemmCoord& /* warp_shape */, cudaStream_t /* stream */) {} | |||
const GemmCoord& /* warp_shape */, int /* stages */, | |||
cudaStream_t /* stream */) {} | |||
#else | |||
template <bool NeedLoadFromConstMem> | |||
void megdnn::cuda::cutlass_wrapper:: | |||
@@ -172,15 +179,15 @@ void megdnn::cuda::cutlass_wrapper:: | |||
int* workspace, const convolution::ConvParam& param, | |||
uint32_t nonlinear_mode, float alpha, float beta, float gamma, | |||
float scale, const GemmCoord& threadblock_shape, | |||
const GemmCoord& warp_shape, cudaStream_t stream) { | |||
const GemmCoord& warp_shape, int stages, cudaStream_t stream) { | |||
#define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \ | |||
threadblock_k_, warp_m_, warp_n_, \ | |||
warp_k_) \ | |||
warp_k_, stage_) \ | |||
if (threadblock_shape.m() == threadblock_m_ && \ | |||
threadblock_shape.n() == threadblock_n_ && \ | |||
threadblock_shape.k() == threadblock_k_ && \ | |||
warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ | |||
warp_shape.k() == warp_k_) { \ | |||
warp_shape.k() == warp_k_ && stages == stage_) { \ | |||
using ThreadBlockShape = \ | |||
cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \ | |||
threadblock_k_>; \ | |||
@@ -196,7 +203,7 @@ void megdnn::cuda::cutlass_wrapper:: | |||
ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ | |||
cutlass::conv::threadblock:: \ | |||
ConvolutionFpropNCxHWxThreadblockSwizzle, \ | |||
2, 16, 16, NeedLoadFromConstMem>; \ | |||
stage_, 16, 16, NeedLoadFromConstMem>; \ | |||
typename Convolution::ConvolutionParameter conv_param( \ | |||
param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ | |||
param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ | |||
@@ -206,13 +213,15 @@ void megdnn::cuda::cutlass_wrapper:: | |||
epilogue, stream); \ | |||
} | |||
#define DISPATCH_KERNEL \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(256, 128, 64, 64, 64, 64); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 256, 64, 64, 64, 64); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 64, 64, 64, 64); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 64, 32, 64, 64); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 32, 64); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 64, 64, 32, 32, 64); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 64, 16, 32, 64); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 256, 64, 64, 64, 64, 2); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(256, 128, 64, 64, 64, 64, 2); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 64, 64, 64, 64, 2); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 32, 64, 2); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 64, 32, 64, 64, 2); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 32, 64, 32, 32, 1); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 32, 64, 32, 32, 1); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 32, 64, 32, 1); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 1); \ | |||
megdnn_assert(false, \ | |||
"unsupported threadblock shape (%dx%dx%d) and warp shape " \ | |||
"(%dx%dx%d)", \ | |||
@@ -268,7 +277,8 @@ void megdnn::cuda::cutlass_wrapper:: | |||
uint32_t nonlinear_mode, float alpha, float beta, \ | |||
float gamma, float scale, \ | |||
const GemmCoord& threadblock_shape, \ | |||
const GemmCoord& warp_shape, cudaStream_t stream); | |||
const GemmCoord& warp_shape, int stages, \ | |||
cudaStream_t stream); | |||
INST(true); | |||
INST(false); | |||
#undef INST | |||
@@ -337,10 +347,8 @@ void megdnn::cuda::cutlass_wrapper:: | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 64, 32, 32, 2, 16); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 32, 64, 32, 32, 2, 16); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 2, 16); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 64, 32, 64, 32, 32, 2, 16); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 32, 32, 64, 32, 2, 16); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 32, 32, 64, 32, 32, 2, 16); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 32, 32, 32, 32, 32, 2, 16); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 128, 16, 16, 128, 16, 1, 8); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 64, 8, 16, 64, 8, 2, 4); \ | |||
megdnn_assert(false, \ | |||
@@ -468,10 +476,8 @@ void megdnn::cuda::cutlass_wrapper:: | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 64, 32, 32, 2, 16); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 32, 64, 32, 32, 2, 16); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 2, 16); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 64, 32, 64, 32, 32, 2, 16); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 32, 32, 64, 32, 2, 16); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 32, 32, 64, 32, 32, 2, 16); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 32, 32, 32, 32, 32, 2, 16); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 128, 16, 16, 128, 16, 1, 8); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 64, 8, 16, 64, 8, 2, 4); \ | |||
megdnn_assert(false, \ | |||
@@ -599,10 +605,8 @@ void megdnn::cuda::cutlass_wrapper:: | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 64, 32, 32, 2, 16); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 32, 64, 32, 32, 2, 16); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 2, 16); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 64, 32, 64, 32, 32, 2, 16); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 32, 32, 64, 32, 2, 16); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 32, 32, 64, 32, 32, 2, 16); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 32, 32, 32, 32, 32, 2, 16); \ | |||
megdnn_assert(false, \ | |||
"unsupported threadblock shape (%dx%dx%d) and warp shape " \ | |||
"(%dx%dx%d)", \ | |||
@@ -664,246 +668,6 @@ INST(true); | |||
INST(false); | |||
#undef INST | |||
/* ====== cutlass kernel wrapper for int4 x int4 nchw64 layout ====== */ | |||
#if MEGDNN_TEGRA_X1 | |||
template <bool NeedLoadFromConstMem> | |||
void megdnn::cuda::cutlass_wrapper:: | |||
do_conv_bias_int4_int4_implicit_gemm_imma_ncdiv64hw64( | |||
const int8_t* /* d_src */, const int8_t* /* d_filter */, | |||
const int32_t* /* d_bias */, const int8_t* /* d_z */, | |||
int8_t* /* d_dst */, int* /* workspace */, | |||
const convolution::ConvParam& /* param */, | |||
uint32_t /* nonlinear_mode */, float /* alpha */, | |||
float /* beta */, float /* gamma */, float /* scale */, | |||
const GemmCoord& /* threadblock_shape */, | |||
const GemmCoord& /* warp_shape */, cudaStream_t /* stream */) {} | |||
#else | |||
template <bool NeedLoadFromConstMem> | |||
void megdnn::cuda::cutlass_wrapper:: | |||
do_conv_bias_int4_int4_implicit_gemm_imma_ncdiv64hw64( | |||
const int8_t* d_src, const int8_t* d_filter, | |||
const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, | |||
int* workspace, const convolution::ConvParam& param, | |||
uint32_t nonlinear_mode, float alpha, float beta, float gamma, | |||
float scale, const GemmCoord& threadblock_shape, | |||
const GemmCoord& warp_shape, cudaStream_t stream) { | |||
#define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \ | |||
threadblock_k_, warp_m_, warp_n_, \ | |||
warp_k_) \ | |||
if (threadblock_shape.m() == threadblock_m_ && \ | |||
threadblock_shape.n() == threadblock_n_ && \ | |||
threadblock_shape.k() == threadblock_k_ && \ | |||
warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ | |||
warp_shape.k() == warp_k_) { \ | |||
using ThreadBlockShape = \ | |||
cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \ | |||
threadblock_k_>; \ | |||
using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \ | |||
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; \ | |||
using Convolution = cutlass::conv::device::Convolution< \ | |||
cutlass::int4b_t, cutlass::layout::TensorNCxHWx<64>, \ | |||
cutlass::int4b_t, cutlass::layout::TensorCxRSKx<64>, \ | |||
ElementOutput, cutlass::layout::TensorNCxHWx<64>, int32_t, \ | |||
cutlass::layout::TensorNCxHWx<64>, int32_t, \ | |||
cutlass::conv::ConvType::kConvolution, \ | |||
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \ | |||
ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ | |||
cutlass::conv::threadblock:: \ | |||
ConvolutionFpropNCxHWxThreadblockSwizzle, \ | |||
2, 32, 32, NeedLoadFromConstMem>; \ | |||
typename Convolution::ConvolutionParameter conv_param( \ | |||
param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ | |||
param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ | |||
param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \ | |||
return cutlass_convolution_wrapper<Convolution>( \ | |||
reinterpret_cast<const cutlass::int4b_t*>(d_src), \ | |||
reinterpret_cast<const cutlass::int4b_t*>(d_filter), d_bias, \ | |||
reinterpret_cast<const cutlass::int4b_t*>(d_z), \ | |||
reinterpret_cast<cutlass::int4b_t*>(d_dst), workspace, \ | |||
conv_param, epilogue, stream); \ | |||
} | |||
#define DISPATCH_KERNEL \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 128, 64, 64, 128); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(256, 128, 128, 64, 64, 128); \ | |||
megdnn_assert(false, \ | |||
"unsupported threadblock shape (%dx%dx%d) and warp shape " \ | |||
"(%dx%dx%d)", \ | |||
threadblock_shape.m(), threadblock_shape.n(), \ | |||
threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ | |||
warp_shape.k()); | |||
using ElementOutput = cutlass::int4b_t; | |||
using ElementAccumulator = int32_t; | |||
using ElementBias = int32_t; | |||
using ElementCompute = float; | |||
using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode; | |||
switch (nonlinear_mode) { | |||
case NonlineMode::IDENTITY: { | |||
using EpilogueOp = | |||
cutlass::epilogue::thread::BiasAddLinearCombinationClamp< | |||
ElementOutput, 16, ElementAccumulator, ElementBias, | |||
ElementCompute>; | |||
typename EpilogueOp::Params epilogue{alpha, beta, gamma}; | |||
DISPATCH_KERNEL; | |||
} | |||
case NonlineMode::RELU: { | |||
using EpilogueOp = cutlass::epilogue::thread:: | |||
BiasAddLinearCombinationReluClamp< | |||
ElementOutput, 16, ElementAccumulator, ElementBias, | |||
ElementCompute>; | |||
typename EpilogueOp::Params epilogue{alpha, beta, gamma, 0}; | |||
DISPATCH_KERNEL; | |||
} | |||
case NonlineMode::H_SWISH: { | |||
using EpilogueOp = cutlass::epilogue::thread:: | |||
BiasAddLinearCombinationHSwishClamp< | |||
ElementOutput, 16, ElementAccumulator, ElementBias, | |||
ElementCompute>; | |||
typename EpilogueOp::Params epilogue{alpha, beta, gamma, scale}; | |||
DISPATCH_KERNEL; | |||
} | |||
default: | |||
megdnn_assert(false, | |||
"unsupported nonlinear mode for conv bias operator"); | |||
} | |||
#undef DISPATCH_KERNEL_WITH_TILE_SHAPE | |||
#undef DISPATCH_KERNEL | |||
} | |||
#endif | |||
#define INST(need_load_from_const_mem) \ | |||
template void megdnn::cuda::cutlass_wrapper:: \ | |||
do_conv_bias_int4_int4_implicit_gemm_imma_ncdiv64hw64< \ | |||
need_load_from_const_mem>( \ | |||
const int8_t* d_src, const int8_t* d_filter, \ | |||
const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, \ | |||
int* workspace, const convolution::ConvParam& param, \ | |||
uint32_t nonlinear_mode, float alpha, float beta, \ | |||
float gamma, float scale, \ | |||
const GemmCoord& threadblock_shape, \ | |||
const GemmCoord& warp_shape, cudaStream_t stream); | |||
INST(true); | |||
#undef INST | |||
/* ====== cutlass kernel wrapper for uint4 x int4 nchw64 layout ====== */ | |||
#if MEGDNN_TEGRA_X1 | |||
template <bool NeedLoadFromConstMem> | |||
void megdnn::cuda::cutlass_wrapper:: | |||
do_conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64( | |||
const uint8_t* /* d_src */, const int8_t* /* d_filter */, | |||
const int32_t* /* d_bias */, const uint8_t* /* d_z */, | |||
uint8_t* /* d_dst */, int* /* workspace */, | |||
const convolution::ConvParam& /* param */, | |||
uint32_t /* nonlinear_mode */, float /* alpha */, | |||
float /* beta */, float /* gamma */, float /* delta */, | |||
float /* theta */, float /* scale */, | |||
uint8_t /* src_zero_point */, | |||
const GemmCoord& /* threadblock_shape */, | |||
const GemmCoord& /* warp_shape */, cudaStream_t /* stream */) {} | |||
#else | |||
template <bool NeedLoadFromConstMem> | |||
void megdnn::cuda::cutlass_wrapper:: | |||
do_conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64( | |||
const uint8_t* d_src, const int8_t* d_filter, | |||
const int32_t* d_bias, const uint8_t* d_z, uint8_t* d_dst, | |||
int* workspace, const convolution::ConvParam& param, | |||
uint32_t nonlinear_mode, float alpha, float beta, float gamma, | |||
float delta, float theta, float /* scale */, | |||
uint8_t src_zero_point, const GemmCoord& threadblock_shape, | |||
const GemmCoord& warp_shape, cudaStream_t stream) { | |||
#define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \ | |||
threadblock_k_, warp_m_, warp_n_, \ | |||
warp_k_) \ | |||
if (threadblock_shape.m() == threadblock_m_ && \ | |||
threadblock_shape.n() == threadblock_n_ && \ | |||
threadblock_shape.k() == threadblock_k_ && \ | |||
warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ | |||
warp_shape.k() == warp_k_) { \ | |||
using ThreadBlockShape = \ | |||
cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \ | |||
threadblock_k_>; \ | |||
using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \ | |||
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; \ | |||
using Convolution = cutlass::conv::device::Convolution< \ | |||
cutlass::uint4b_t, cutlass::layout::TensorNCxHWx<64>, \ | |||
cutlass::int4b_t, cutlass::layout::TensorCxRSKx<64>, \ | |||
ElementOutput, cutlass::layout::TensorNCxHWx<64>, int32_t, \ | |||
cutlass::layout::TensorNCxHWx<64>, int32_t, \ | |||
cutlass::conv::ConvType::kConvolution, \ | |||
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \ | |||
ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ | |||
cutlass::conv::threadblock:: \ | |||
ConvolutionFpropNCxHWxThreadblockSwizzle, \ | |||
2, 32, 32, NeedLoadFromConstMem>; \ | |||
typename Convolution::ConvolutionParameter conv_param( \ | |||
param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ | |||
param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ | |||
param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \ | |||
return cutlass_convolution_wrapper<Convolution>( \ | |||
reinterpret_cast<const cutlass::uint4b_t*>(d_src), \ | |||
reinterpret_cast<const cutlass::int4b_t*>(d_filter), d_bias, \ | |||
reinterpret_cast<const cutlass::uint4b_t*>(d_z), \ | |||
reinterpret_cast<cutlass::uint4b_t*>(d_dst), workspace, \ | |||
conv_param, epilogue, stream, {src_zero_point}); \ | |||
} | |||
#define DISPATCH_KERNEL \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 128, 64, 64, 128); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(256, 128, 128, 64, 64, 128); \ | |||
megdnn_assert(false, \ | |||
"unsupported threadblock shape (%dx%dx%d) and warp shape " \ | |||
"(%dx%dx%d)", \ | |||
threadblock_shape.m(), threadblock_shape.n(), \ | |||
threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ | |||
warp_shape.k()); | |||
using ElementOutput = cutlass::uint4b_t; | |||
using ElementAccumulator = int32_t; | |||
using ElementBias = int32_t; | |||
using ElementCompute = float; | |||
using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode; | |||
switch (nonlinear_mode) { | |||
case NonlineMode::IDENTITY: { | |||
using EpilogueOp = | |||
cutlass::epilogue::thread::BiasAddLinearCombinationClamp< | |||
ElementOutput, 16, ElementAccumulator, ElementBias, | |||
ElementCompute>; | |||
typename EpilogueOp::Params epilogue{alpha, beta, gamma, | |||
delta + theta}; | |||
DISPATCH_KERNEL; | |||
} | |||
case NonlineMode::RELU: { | |||
using EpilogueOp = cutlass::epilogue::thread:: | |||
BiasAddLinearCombinationReluClamp< | |||
ElementOutput, 16, ElementAccumulator, ElementBias, | |||
ElementCompute>; | |||
typename EpilogueOp::Params epilogue{alpha, beta, gamma, | |||
0, delta, theta}; | |||
DISPATCH_KERNEL; | |||
} | |||
default: | |||
megdnn_assert(false, | |||
"unsupported nonlinear mode for conv bias operator"); | |||
} | |||
#undef DISPATCH_KERNEL_WITH_TILE_SHAPE | |||
#undef DISPATCH_KERNEL | |||
} | |||
#endif | |||
#define INST(need_load_from_const_mem) \ | |||
template void megdnn::cuda::cutlass_wrapper:: \ | |||
do_conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64< \ | |||
need_load_from_const_mem>( \ | |||
const uint8_t* d_src, const int8_t* d_filter, \ | |||
const int32_t* d_bias, const uint8_t* d_z, uint8_t* d_dst, \ | |||
int* workspace, const convolution::ConvParam& param, \ | |||
uint32_t nonlinear_mode, float alpha, float beta, \ | |||
float gamma, float delta, float theta, float scale, \ | |||
uint8_t src_zero_point, \ | |||
const GemmCoord& threadblock_shape, \ | |||
const GemmCoord& warp_shape, cudaStream_t stream); | |||
INST(true); | |||
#undef INST | |||
/* ===== cutlass kernel wrapper for nchw4 layout and nhwc output ===== */ | |||
#if MEGDNN_TEGRA_X1 | |||
template <bool signedness> | |||
@@ -970,10 +734,8 @@ void megdnn::cuda::cutlass_wrapper:: | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 64, 32, 32, 2, 16); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 32, 64, 32, 32, 2, 16); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 2, 16); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 64, 32, 64, 32, 32, 2, 16); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 32, 32, 64, 32, 2, 16); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 32, 32, 64, 32, 32, 2, 16); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 32, 32, 32, 32, 32, 2, 16); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 128, 16, 16, 128, 16, 1, 8); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 64, 8, 16, 64, 8, 2, 4); \ | |||
megdnn_assert(false, \ | |||
@@ -1039,262 +801,4 @@ INST(true); | |||
INST(false); | |||
#undef INST | |||
/* ====== cutlass kernel wrapper for int4 x int4 nchw64 layout ====== */ | |||
#if MEGDNN_TEGRA_X1 | |||
template <bool NeedLoadFromConstMem> | |||
void megdnn::cuda::cutlass_wrapper:: | |||
do_conv_bias_int4_int4_implicit_gemm_imma_nhwc( | |||
const int8_t* /* d_src */, const int8_t* /* d_filter */, | |||
const int32_t* /* d_bias */, const int8_t* /* d_z */, | |||
int8_t* /* d_dst */, int* /* workspace */, | |||
const convolution::ConvParam& /* param */, | |||
uint32_t /* nonlinear_mode */, float /* alpha */, | |||
float /* beta */, float /* gamma */, float /* scale */, | |||
const GemmCoord& /* threadblock_shape */, | |||
const GemmCoord& /* warp_shape */, | |||
const int32_t /* access_size */, cudaStream_t /* stream */) {} | |||
#else | |||
template <bool NeedLoadFromConstMem> | |||
void megdnn::cuda::cutlass_wrapper:: | |||
do_conv_bias_int4_int4_implicit_gemm_imma_nhwc( | |||
const int8_t* d_src, const int8_t* d_filter, | |||
const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, | |||
int* workspace, const convolution::ConvParam& param, | |||
uint32_t nonlinear_mode, float alpha, float beta, float gamma, | |||
float scale, const GemmCoord& threadblock_shape, | |||
const GemmCoord& warp_shape, const int32_t access_size, | |||
cudaStream_t stream) { | |||
#define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \ | |||
threadblock_k_, warp_m_, warp_n_, \ | |||
warp_k_, access_size_) \ | |||
if (threadblock_shape.m() == threadblock_m_ && \ | |||
threadblock_shape.n() == threadblock_n_ && \ | |||
threadblock_shape.k() == threadblock_k_ && \ | |||
warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ | |||
warp_shape.k() == warp_k_ && access_size == access_size_) { \ | |||
using ThreadBlockShape = \ | |||
cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \ | |||
threadblock_k_>; \ | |||
using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \ | |||
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; \ | |||
using Convolution = cutlass::conv::device::Convolution< \ | |||
cutlass::int4b_t, cutlass::layout::TensorNHWC, \ | |||
cutlass::int4b_t, cutlass::layout::TensorNCxHWx<access_size_>, \ | |||
ElementOutput, cutlass::layout::TensorNHWC, int32_t, \ | |||
cutlass::layout::TensorNHWC, int32_t, \ | |||
cutlass::conv::ConvType::kConvolution, \ | |||
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \ | |||
ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ | |||
cutlass::conv::threadblock:: \ | |||
ConvolutionFpropNHWCThreadblockSwizzle, \ | |||
2, access_size_, access_size_, NeedLoadFromConstMem, \ | |||
cutlass::arch::OpMultiplyAddSaturate, \ | |||
cutlass::conv::ImplicitGemmMode::GEMM_TN>; \ | |||
typename Convolution::ConvolutionParameter conv_param( \ | |||
param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ | |||
param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ | |||
param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \ | |||
return cutlass_convolution_wrapper<Convolution>( \ | |||
reinterpret_cast<const cutlass::int4b_t*>(d_src), \ | |||
reinterpret_cast<const cutlass::int4b_t*>(d_filter), d_bias, \ | |||
reinterpret_cast<const cutlass::int4b_t*>(d_z), \ | |||
reinterpret_cast<cutlass::int4b_t*>(d_dst), workspace, \ | |||
conv_param, epilogue, stream); \ | |||
} | |||
#define DISPATCH_KERNEL \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 32); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 16); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 8); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 32); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 16); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 8); \ | |||
megdnn_assert(false, \ | |||
"unsupported threadblock shape (%dx%dx%d) and warp shape " \ | |||
"(%dx%dx%d) and access_size (%d)", \ | |||
threadblock_shape.m(), threadblock_shape.n(), \ | |||
threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ | |||
warp_shape.k(), access_size); | |||
using ElementOutput = cutlass::int4b_t; | |||
using ElementAccumulator = int32_t; | |||
using ElementBias = int32_t; | |||
using ElementCompute = float; | |||
using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode; | |||
switch (nonlinear_mode) { | |||
case NonlineMode::IDENTITY: { | |||
using EpilogueOp = | |||
cutlass::epilogue::thread::BiasAddLinearCombinationClamp< | |||
ElementOutput, 8, ElementAccumulator, ElementBias, | |||
ElementCompute>; | |||
typename EpilogueOp::Params epilogue{alpha, beta, gamma}; | |||
DISPATCH_KERNEL; | |||
} | |||
case NonlineMode::RELU: { | |||
using EpilogueOp = cutlass::epilogue::thread:: | |||
BiasAddLinearCombinationReluClamp< | |||
ElementOutput, 8, ElementAccumulator, ElementBias, | |||
ElementCompute>; | |||
typename EpilogueOp::Params epilogue{alpha, beta, gamma, 0}; | |||
DISPATCH_KERNEL; | |||
} | |||
case NonlineMode::H_SWISH: { | |||
using EpilogueOp = cutlass::epilogue::thread:: | |||
BiasAddLinearCombinationHSwishClamp< | |||
ElementOutput, 8, ElementAccumulator, ElementBias, | |||
ElementCompute>; | |||
typename EpilogueOp::Params epilogue{alpha, beta, gamma, scale}; | |||
DISPATCH_KERNEL; | |||
} | |||
default: | |||
megdnn_assert(false, | |||
"unsupported nonlinear mode for conv bias operator"); | |||
} | |||
#undef DISPATCH_KERNEL_WITH_TILE_SHAPE | |||
#undef DISPATCH_KERNEL | |||
} | |||
#endif | |||
#define INST(need_load_from_const_mem) \ | |||
template void megdnn::cuda::cutlass_wrapper:: \ | |||
do_conv_bias_int4_int4_implicit_gemm_imma_nhwc< \ | |||
need_load_from_const_mem>( \ | |||
const int8_t* d_src, const int8_t* d_filter, \ | |||
const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, \ | |||
int* workspace, const convolution::ConvParam& param, \ | |||
uint32_t nonlinear_mode, float alpha, float beta, \ | |||
float gamma, float scale, \ | |||
const GemmCoord& threadblock_shape, \ | |||
const GemmCoord& warp_shape, const int32_t access_size, \ | |||
cudaStream_t stream); | |||
INST(true); | |||
INST(false); | |||
#undef INST | |||
#if MEGDNN_TEGRA_X1 | |||
template <bool NeedLoadFromConstMem> | |||
void megdnn::cuda::cutlass_wrapper:: | |||
do_conv_bias_uint4_int4_implicit_gemm_imma_nhwc( | |||
const uint8_t* /* d_src */, const int8_t* /* d_filter */, | |||
const int32_t* /* d_bias */, const uint8_t* /* d_z */, | |||
uint8_t* /* d_dst */, int* /* workspace */, | |||
const convolution::ConvParam& /* param */, | |||
uint32_t /* nonlinear_mode */, float /* alpha */, | |||
float /* beta */, float /* gamma */, float /* delta */, | |||
float /* theta */, float /* scale */, | |||
uint8_t /* src_zero_point */, | |||
const GemmCoord& /* threadblock_shape */, | |||
const GemmCoord& /* warp_shape */, | |||
const int32_t /* access_size */, cudaStream_t /* stream */) {} | |||
#else | |||
template <bool NeedLoadFromConstMem> | |||
void megdnn::cuda::cutlass_wrapper:: | |||
do_conv_bias_uint4_int4_implicit_gemm_imma_nhwc( | |||
const uint8_t* d_src, const int8_t* d_filter, | |||
const int32_t* d_bias, const uint8_t* d_z, uint8_t* d_dst, | |||
int* workspace, const convolution::ConvParam& param, | |||
uint32_t nonlinear_mode, float alpha, float beta, float gamma, | |||
float delta, float theta, float /* scale */, | |||
uint8_t src_zero_point, const GemmCoord& threadblock_shape, | |||
const GemmCoord& warp_shape, const int32_t access_size, | |||
cudaStream_t stream) { | |||
#define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \ | |||
threadblock_k_, warp_m_, warp_n_, \ | |||
warp_k_, access_size_) \ | |||
if (threadblock_shape.m() == threadblock_m_ && \ | |||
threadblock_shape.n() == threadblock_n_ && \ | |||
threadblock_shape.k() == threadblock_k_ && \ | |||
warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ | |||
warp_shape.k() == warp_k_ && access_size == access_size_) { \ | |||
using ThreadBlockShape = \ | |||
cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \ | |||
threadblock_k_>; \ | |||
using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \ | |||
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; \ | |||
using Convolution = cutlass::conv::device::Convolution< \ | |||
cutlass::uint4b_t, cutlass::layout::TensorNHWC, \ | |||
cutlass::int4b_t, cutlass::layout::TensorNCxHWx<access_size_>, \ | |||
ElementOutput, cutlass::layout::TensorNHWC, int32_t, \ | |||
cutlass::layout::TensorNHWC, int32_t, \ | |||
cutlass::conv::ConvType::kConvolution, \ | |||
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \ | |||
ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ | |||
cutlass::conv::threadblock:: \ | |||
ConvolutionFpropNHWCThreadblockSwizzle, \ | |||
2, access_size_, access_size_, NeedLoadFromConstMem, \ | |||
cutlass::arch::OpMultiplyAddSaturate, \ | |||
cutlass::conv::ImplicitGemmMode::GEMM_TN>; \ | |||
typename Convolution::ConvolutionParameter conv_param( \ | |||
param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ | |||
param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ | |||
param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \ | |||
return cutlass_convolution_wrapper<Convolution>( \ | |||
reinterpret_cast<const cutlass::uint4b_t*>(d_src), \ | |||
reinterpret_cast<const cutlass::int4b_t*>(d_filter), d_bias, \ | |||
reinterpret_cast<const cutlass::uint4b_t*>(d_z), \ | |||
reinterpret_cast<cutlass::uint4b_t*>(d_dst), workspace, \ | |||
conv_param, epilogue, stream, {src_zero_point}); \ | |||
} | |||
#define DISPATCH_KERNEL \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 32); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 16); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 8); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 32); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 16); \ | |||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 8); \ | |||
megdnn_assert(false, \ | |||
"unsupported threadblock shape (%dx%dx%d) and warp shape " \ | |||
"(%dx%dx%d) and access_size (%d)", \ | |||
threadblock_shape.m(), threadblock_shape.n(), \ | |||
threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ | |||
warp_shape.k(), access_size); | |||
using ElementOutput = cutlass::uint4b_t; | |||
using ElementAccumulator = int32_t; | |||
using ElementBias = int32_t; | |||
using ElementCompute = float; | |||
using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode; | |||
switch (nonlinear_mode) { | |||
case NonlineMode::IDENTITY: { | |||
using EpilogueOp = | |||
cutlass::epilogue::thread::BiasAddLinearCombinationClamp< | |||
ElementOutput, 8, ElementAccumulator, ElementBias, | |||
ElementCompute>; | |||
typename EpilogueOp::Params epilogue{alpha, beta, gamma, | |||
delta + theta}; | |||
DISPATCH_KERNEL; | |||
} | |||
case NonlineMode::RELU: { | |||
using EpilogueOp = cutlass::epilogue::thread:: | |||
BiasAddLinearCombinationReluClamp< | |||
ElementOutput, 8, ElementAccumulator, ElementBias, | |||
ElementCompute>; | |||
typename EpilogueOp::Params epilogue{alpha, beta, gamma, | |||
0, delta, theta}; | |||
DISPATCH_KERNEL; | |||
} | |||
default: | |||
megdnn_assert(false, | |||
"unsupported nonlinear mode for conv bias operator"); | |||
} | |||
#undef DISPATCH_KERNEL_WITH_TILE_SHAPE | |||
#undef DISPATCH_KERNEL | |||
} | |||
#endif | |||
#define INST(need_load_from_const_mem) \ | |||
template void megdnn::cuda::cutlass_wrapper:: \ | |||
do_conv_bias_uint4_int4_implicit_gemm_imma_nhwc< \ | |||
need_load_from_const_mem>( \ | |||
const uint8_t* d_src, const int8_t* d_filter, \ | |||
const int32_t* d_bias, const uint8_t* d_z, uint8_t* d_dst, \ | |||
int* workspace, const convolution::ConvParam& param, \ | |||
uint32_t nonlinear_mode, float alpha, float beta, \ | |||
float gamma, float delta, float theta, float scale, \ | |||
uint8_t src_zero_point, \ | |||
const GemmCoord& threadblock_shape, \ | |||
const GemmCoord& warp_shape, const int32_t access_size, \ | |||
cudaStream_t stream); | |||
INST(true); | |||
INST(false); | |||
#undef INST | |||
// vim: syntax=cuda.doxygen |
@@ -0,0 +1,194 @@ | |||
/** | |||
* \file dnn/src/cuda/conv_bias/cutlass_reorder_filter.cu | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "src/cuda/conv_bias/cutlass_reorder_filter.cuh" | |||
#include "src/cuda/query_blocksize.cuh" | |||
#include "src/cuda/integer_subbyte_utils.cuh" | |||
using namespace megdnn; | |||
using namespace cuda; | |||
using namespace cutlass_wrapper; | |||
namespace { | |||
template <uint32_t size_bits, uint32_t interleaved> | |||
__device__ __forceinline__ void reorder_ncxhwx_imma_filter_func( | |||
int8_t* dst, const int8_t* src, uint32_t OC, uint32_t IC, uint32_t FH, | |||
uint32_t FW, uint32_t lane, bool trans_oc) { | |||
static constexpr uint32_t elements_per_lane = 128 / size_bits; | |||
static constexpr uint32_t threads_per_interleaved = | |||
interleaved / elements_per_lane; | |||
static constexpr uint32_t instruction_shape_col = 8; | |||
// 4 threads per Quad | |||
static constexpr uint32_t elements_per_thread = instruction_shape_col / 4; | |||
// 4 threads per Quad | |||
static constexpr uint32_t reordered_elements_per_thread = interleaved / 4; | |||
uint32_t id = lane / threads_per_interleaved; | |||
uint32_t residue = lane % threads_per_interleaved; | |||
uint32_t ICx = IC / interleaved; | |||
uint32_t row = id / (ICx * FH * FW); | |||
uint32_t col = id - row * ICx * FH * FW; | |||
// transpose ncxhwx to cxhwnx | |||
uint32_t src_offset = id * interleaved + residue * elements_per_lane; | |||
row = (trans_oc) ? (row / interleaved) * interleaved + | |||
((row % reordered_elements_per_thread) / | |||
elements_per_thread) * | |||
instruction_shape_col + | |||
((row % interleaved) / | |||
reordered_elements_per_thread) * | |||
elements_per_thread + | |||
(row % elements_per_thread) | |||
: row; | |||
uint32_t dst_offset = | |||
(col * OC + row) * interleaved + residue * elements_per_lane; | |||
*(reinterpret_cast<int4*>(dst + dst_offset * size_bits / 8)) = | |||
*(reinterpret_cast<const int4*>(src + src_offset * size_bits / 8)); | |||
} | |||
template <uint32_t size_bits, uint32_t interleaved> | |||
__global__ void reorder_ncxhwx_imma_filter_kernel( | |||
int8_t* __restrict__ dst_filter, const int8_t* __restrict__ src_filter, | |||
uint32_t OC, uint32_t IC, uint32_t FH, uint32_t FW, bool trans_oc) { | |||
static constexpr uint32_t elements_per_lane = 128 / size_bits; | |||
const uint32_t size = OC * IC * FH * FW / elements_per_lane; | |||
uint32_t lane = threadIdx.x + blockIdx.x * blockDim.x; | |||
if (lane < size) { | |||
reorder_ncxhwx_imma_filter_func<size_bits, interleaved>( | |||
dst_filter, src_filter, OC, IC, FH, FW, lane, trans_oc); | |||
} | |||
} | |||
template <uint32_t size_bits, uint32_t alignbits, uint32_t interleaved> | |||
__device__ __forceinline__ void reorder_nhwc_imma_filter_func( | |||
int8_t* dst, const int8_t* src, uint32_t OC, uint32_t IC, uint32_t FH, | |||
uint32_t FW, uint32_t lane, bool trans_oc) { | |||
static constexpr uint32_t elements_per_access = alignbits / size_bits; | |||
static constexpr uint32_t instruction_shape_col = 8; | |||
// 4 threads per Quad | |||
static constexpr uint32_t elements_per_thread = instruction_shape_col / 4; | |||
// 4 threads per Quad | |||
static constexpr uint32_t reordered_elements_per_thread = interleaved / 4; | |||
uint32_t ICx = IC / elements_per_access; | |||
uint32_t k = lane / (ICx * FH * FW); | |||
uint32_t cxrs = lane - k * ICx * FH * FW; | |||
uint32_t rs = cxrs / ICx; | |||
uint32_t cx = cxrs - rs * ICx; | |||
// transpose nhwc to ncxhwx | |||
uint32_t src_offset = lane * elements_per_access; | |||
// reorder k | |||
k = (trans_oc) | |||
? (k / interleaved) * interleaved + | |||
((k % reordered_elements_per_thread) / | |||
elements_per_thread) * | |||
instruction_shape_col + | |||
((k % interleaved) / reordered_elements_per_thread) * | |||
elements_per_thread + | |||
(k % elements_per_thread) | |||
: k; | |||
uint32_t dst_offset = | |||
(k * ICx * FH * FW + cx * FH * FW + rs) * elements_per_access; | |||
if (alignbits == 32) { | |||
*(reinterpret_cast<int*>(dst + dst_offset * size_bits / 8)) = *( | |||
reinterpret_cast<const int*>(src + src_offset * size_bits / 8)); | |||
} else if (alignbits == 64) { | |||
*(reinterpret_cast<int2*>(dst + dst_offset * size_bits / 8)) = | |||
*(reinterpret_cast<const int2*>(src + | |||
src_offset * size_bits / 8)); | |||
} else { | |||
*(reinterpret_cast<int4*>(dst + dst_offset * size_bits / 8)) = | |||
*(reinterpret_cast<const int4*>(src + | |||
src_offset * size_bits / 8)); | |||
} | |||
} | |||
template <uint32_t size_bits, uint32_t alignbits, uint32_t interleaved> | |||
__global__ void reorder_nhwc_imma_filter_kernel( | |||
int8_t* __restrict__ dst_filter, const int8_t* __restrict__ src_filter, | |||
uint32_t OC, uint32_t IC, uint32_t FH, uint32_t FW, bool trans_oc) { | |||
static constexpr uint32_t elements_per_access = alignbits / size_bits; | |||
const uint32_t size = OC * IC * FH * FW / elements_per_access; | |||
uint32_t lane = threadIdx.x + blockIdx.x * blockDim.x; | |||
if (lane < size) { | |||
reorder_nhwc_imma_filter_func<size_bits, alignbits, interleaved>( | |||
dst_filter, src_filter, OC, IC, FH, FW, lane, trans_oc); | |||
} | |||
} | |||
} // namespace | |||
template <uint32_t size_bits, uint32_t interleaved> | |||
void megdnn::cuda::cutlass_wrapper::reorder_ncxhwx_imma_filter( | |||
int8_t* dst_filter, const int8_t* src_filter, uint32_t OC, uint32_t IC, | |||
uint32_t FH, uint32_t FW, bool trans_oc, cudaStream_t stream) { | |||
static constexpr uint32_t elements_per_lane = 128 / size_bits; | |||
uint32_t nr_threads = | |||
query_blocksize_for_kernel(reinterpret_cast<const void*>( | |||
reorder_ncxhwx_imma_filter_kernel<size_bits, interleaved>)); | |||
uint32_t vthreads = DIVUP(OC * IC * FH * FW, elements_per_lane); | |||
nr_threads = std::min(nr_threads, vthreads); | |||
uint32_t nr_blocks = DIVUP(vthreads, nr_threads); | |||
reorder_ncxhwx_imma_filter_kernel<size_bits, interleaved> | |||
<<<nr_blocks, nr_threads, 0, stream>>>(dst_filter, src_filter, OC, | |||
IC, FH, FW, trans_oc); | |||
after_kernel_launch(); | |||
} | |||
template <uint32_t size_bits, uint32_t alignbits> | |||
void megdnn::cuda::cutlass_wrapper::reorder_nhwc_imma_filter( | |||
int8_t* dst_filter, const int8_t* src_filter, uint32_t OC, uint32_t IC, | |||
uint32_t FH, uint32_t FW, bool trans_oc, uint32_t oc_interleaved, | |||
cudaStream_t stream) { | |||
static constexpr uint32_t elements_per_access = alignbits / size_bits; | |||
uint32_t nr_threads = | |||
query_blocksize_for_kernel(reinterpret_cast<const void*>( | |||
reorder_nhwc_imma_filter_kernel<size_bits, alignbits, 32>)); | |||
uint32_t vthreads = DIVUP(OC * IC * FH * FW, elements_per_access); | |||
nr_threads = std::min(nr_threads, vthreads); | |||
uint32_t nr_blocks = DIVUP(vthreads, nr_threads); | |||
if (oc_interleaved == 32) { | |||
reorder_nhwc_imma_filter_kernel<size_bits, alignbits, 32> | |||
<<<nr_blocks, nr_threads, 0, stream>>>( | |||
dst_filter, src_filter, OC, IC, FH, FW, trans_oc); | |||
} else { | |||
reorder_nhwc_imma_filter_kernel<size_bits, alignbits, 64> | |||
<<<nr_blocks, nr_threads, 0, stream>>>( | |||
dst_filter, src_filter, OC, IC, FH, FW, trans_oc); | |||
} | |||
after_kernel_launch(); | |||
} | |||
#define INST(_size_bits, _interleaved) \ | |||
template void megdnn::cuda::cutlass_wrapper::reorder_ncxhwx_imma_filter< \ | |||
_size_bits, _interleaved>(int8_t * dst_filter, \ | |||
const int8_t* src_filter, uint32_t OC, \ | |||
uint32_t IC, uint32_t FH, uint32_t FW, \ | |||
bool trans_oc, cudaStream_t stream); | |||
INST(8, 32) | |||
INST(4, 64) | |||
#undef INST | |||
#define INST(_size_bits, _alignbits) \ | |||
template void megdnn::cuda::cutlass_wrapper::reorder_nhwc_imma_filter< \ | |||
_size_bits, _alignbits>( \ | |||
int8_t * dst_filter, const int8_t* src_filter, uint32_t OC, \ | |||
uint32_t IC, uint32_t FH, uint32_t FW, bool trans_oc, \ | |||
uint32_t oc_interleaved, cudaStream_t stream); | |||
INST(4, 32) | |||
INST(4, 64) | |||
INST(4, 128) | |||
#undef INST | |||
// vim: syntax=cuda.doxygen |
@@ -0,0 +1,33 @@ | |||
/** | |||
* \file dnn/src/cuda/conv_bias/cutlass_reorder_filter.cuh | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#pragma once | |||
#include "src/cuda/utils.cuh" | |||
namespace megdnn { | |||
namespace cuda { | |||
namespace cutlass_wrapper { | |||
template <uint32_t size_bits, uint32_t interleaved> | |||
void reorder_ncxhwx_imma_filter(int8_t* dst_filter, const int8_t* src_filter, | |||
uint32_t OC, uint32_t IC, uint32_t FH, | |||
uint32_t FW, bool trans_oc, | |||
cudaStream_t stream); | |||
template <uint32_t size_bits, uint32_t alignbits> | |||
void reorder_nhwc_imma_filter(int8_t* dst_filter, const int8_t* src_filter, | |||
uint32_t OC, uint32_t IC, uint32_t FH, | |||
uint32_t FW, bool trans_oc, | |||
uint32_t oc_interleaved, cudaStream_t stream); | |||
} // namespace cutlass_wrapper | |||
} // namespace cuda | |||
} // namespace megdnn |
@@ -102,7 +102,7 @@ void ConvBiasForwardImpl::AlgoInt4Int4NCHW64IMMAImplicitGemm::do_exec( | |||
reinterpret_cast<int8_t*>(z_ptr), | |||
reinterpret_cast<int8_t*>(args.dst_tensor->raw_ptr), nullptr, | |||
kern_param, nonlinear_mode, alpha, beta, gamma, dst_scale, | |||
threadblock_shape, warp_shape, stream); | |||
threadblock_shape, warp_shape, m_algo_param.stage, stream); | |||
} | |||
#endif | |||
@@ -104,7 +104,7 @@ void ConvBiasForwardImpl::AlgoInt4Int4NHWCIMMAImplicitGemm::do_exec( | |||
reinterpret_cast<int8_t*>(args.dst_tensor->raw_ptr), nullptr, | |||
kern_param, nonlinear_mode, alpha, beta, gamma, dst_scale, | |||
threadblock_shape, warp_shape, m_algo_param.access_size, | |||
stream); | |||
m_algo_param.stage, stream); | |||
} else { | |||
cutlass_wrapper::do_conv_bias_int4_int4_implicit_gemm_imma_nhwc<true>( | |||
reinterpret_cast<int8_t*>(args.src_tensor->raw_ptr), | |||
@@ -114,7 +114,7 @@ void ConvBiasForwardImpl::AlgoInt4Int4NHWCIMMAImplicitGemm::do_exec( | |||
reinterpret_cast<int8_t*>(args.dst_tensor->raw_ptr), nullptr, | |||
kern_param, nonlinear_mode, alpha, beta, gamma, dst_scale, | |||
threadblock_shape, warp_shape, m_algo_param.access_size, | |||
stream); | |||
m_algo_param.stage, stream); | |||
} | |||
} | |||
#endif | |||
@@ -12,6 +12,7 @@ | |||
#include "./algo.h" | |||
#include "src/common/conv_bias.h" | |||
#include "src/cuda/conv_bias/cutlass_reorder_filter.cuh" | |||
#include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh" | |||
#include "src/cuda/conv_bias/reduce_filter.cuh" | |||
#include "src/cuda/convolution_helper/parameter.cuh" | |||
@@ -121,41 +122,26 @@ void ConvBiasForwardImpl::AlgoInt4NCHW64IMMAImplicitGemmBase::exec( | |||
std::string ConvBiasForwardImpl::AlgoInt4NCHW64IMMAImplicitGemmBase::to_string( | |||
AlgoParam algo_param) { | |||
return ssprintf("%dX%dX%d_%dX%dX%d", algo_param.threadblock_m, | |||
return ssprintf("%dX%dX%d_%dX%dX%d_%d", algo_param.threadblock_m, | |||
algo_param.threadblock_n, algo_param.threadblock_k, | |||
algo_param.warp_m, algo_param.warp_n, algo_param.warp_k); | |||
algo_param.warp_m, algo_param.warp_n, algo_param.warp_k, | |||
algo_param.stage); | |||
} | |||
void ConvBiasForwardImpl::AlgoInt4NCHW64IMMAImplicitGemmBase::reorder_filter( | |||
const ExecArgs& args, void* reordered_filter) const { | |||
auto&& param = args.opr->param(); | |||
size_t ci = args.src_layout->operator[](1) * 64; | |||
size_t co = args.dst_layout->operator[](1) * 64; | |||
auto&& fm = args.filter_meta; | |||
size_t n = args.src_layout->operator[](0), | |||
ci = args.src_layout->operator[](1) * 64, | |||
hi = args.src_layout->operator[](2), | |||
wi = args.src_layout->operator[](3); | |||
size_t co = args.dst_layout->operator[](1) * 64, | |||
ho = args.dst_layout->operator[](2), | |||
wo = args.dst_layout->operator[](3); | |||
UNPACK_CONV_PARAMETER(fm, param); | |||
MARK_USED_VAR; | |||
// filter: KCRS64 => CRSK64 | |||
TensorLayout src{{co, ci / 64, fh, fw, 64}, dtype::QuantizedS4()}; | |||
src.init_contiguous_stride(); | |||
TensorLayout dst = src; | |||
dst.stride[0] = 64; | |||
dst.stride[1] = co * fh * fw * 64; | |||
dst.stride[2] = co * fw * 64; | |||
dst.stride[3] = co * 64; | |||
dst.stride[4] = 1; | |||
TensorND ts_src, ts_dst; | |||
ts_src.raw_ptr = args.filter_tensor->raw_ptr; | |||
ts_src.layout = src; | |||
ts_dst.raw_ptr = reordered_filter; | |||
ts_dst.layout = dst; | |||
auto&& transpose = args.opr->handle()->create_operator<RelayoutForward>(); | |||
transpose->exec(ts_src, ts_dst); | |||
size_t fh = fm.spatial[0], fw = fm.spatial[1]; | |||
cudaStream_t stream = cuda_stream(args.opr->handle()); | |||
// filter: KCRS64 => CRSK64 and reorder oc | |||
cutlass_wrapper::reorder_ncxhwx_imma_filter<4, 64>( | |||
reinterpret_cast<int8_t*>(reordered_filter), | |||
reinterpret_cast<int8_t*>(args.filter_tensor->raw_ptr), co, ci, fh, | |||
fw, true, stream); | |||
} | |||
#endif | |||
@@ -12,6 +12,7 @@ | |||
#include "./algo.h" | |||
#include "src/common/conv_bias.h" | |||
#include "src/cuda/conv_bias/cutlass_reorder_filter.cuh" | |||
#include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh" | |||
#include "src/cuda/conv_bias/reduce_filter.cuh" | |||
#include "src/cuda/convolution_helper/parameter.cuh" | |||
@@ -128,10 +129,10 @@ void ConvBiasForwardImpl::AlgoInt4NHWCIMMAImplicitGemmBase::exec( | |||
std::string ConvBiasForwardImpl::AlgoInt4NHWCIMMAImplicitGemmBase::to_string( | |||
AlgoParam algo_param) { | |||
return ssprintf("%dX%dX%d_%dX%dX%d_%d", algo_param.threadblock_m, | |||
return ssprintf("%dX%dX%d_%dX%dX%d_%d_%d", algo_param.threadblock_m, | |||
algo_param.threadblock_n, algo_param.threadblock_k, | |||
algo_param.warp_m, algo_param.warp_n, algo_param.warp_k, | |||
algo_param.access_size); | |||
algo_param.stage, algo_param.access_size); | |||
} | |||
void ConvBiasForwardImpl::AlgoInt4NHWCIMMAImplicitGemmBase::reorder_filter( | |||
@@ -142,17 +143,32 @@ void ConvBiasForwardImpl::AlgoInt4NHWCIMMAImplicitGemmBase::reorder_filter( | |||
fh = args.filter_layout->operator[](1), | |||
fw = args.filter_layout->operator[](2); | |||
// reformat grad from nhwc to ncxhwx | |||
TensorLayout exec_src{{co, fh, fw, ci / iterleaved, (size_t)iterleaved / 2}, | |||
dtype::Int8()}; | |||
TensorLayout exec_dst{{co, ci / iterleaved, fh, fw, (size_t)iterleaved / 2}, | |||
dtype::Int8()}; | |||
exec_src = exec_src.dimshuffle({0, 3, 1, 2, 4}); | |||
cudaStream_t stream = cuda_stream(args.opr->handle()); | |||
auto&& relayout = args.opr->handle()->create_operator<RelayoutForward>(); | |||
relayout->exec({args.filter_tensor->raw_ptr, exec_src}, | |||
{reordered_filter, exec_dst}); | |||
// reformat filter from nhwc to ncxhwx and reorder oc | |||
// use trans_oc threadblock_n must be 32 or 64 | |||
bool trans_oc = ((co % m_algo_param.threadblock_n == 0) && | |||
(m_algo_param.threadblock_n == 32 || | |||
m_algo_param.threadblock_n == 64)); | |||
uint32_t oc_iterleave = (m_algo_param.threadblock_n == 64) ? 64 : 32; | |||
if (iterleaved == 8) { | |||
cutlass_wrapper::reorder_nhwc_imma_filter<4, 32>( | |||
reinterpret_cast<int8_t*>(reordered_filter), | |||
reinterpret_cast<int8_t*>(args.filter_tensor->raw_ptr), co, ci, | |||
fh, fw, trans_oc, oc_iterleave, stream); | |||
} else if (iterleaved == 16) { | |||
cutlass_wrapper::reorder_nhwc_imma_filter<4, 64>( | |||
reinterpret_cast<int8_t*>(reordered_filter), | |||
reinterpret_cast<int8_t*>(args.filter_tensor->raw_ptr), co, ci, | |||
fh, fw, trans_oc, oc_iterleave, stream); | |||
} else { | |||
megdnn_assert(iterleaved == 32); | |||
cutlass_wrapper::reorder_nhwc_imma_filter<4, 128>( | |||
reinterpret_cast<int8_t*>(reordered_filter), | |||
reinterpret_cast<int8_t*>(args.filter_tensor->raw_ptr), co, ci, | |||
fh, fw, trans_oc, oc_iterleave, stream); | |||
} | |||
} | |||
#endif | |||
@@ -11,6 +11,7 @@ | |||
*/ | |||
#include "./algo.h" | |||
#include "src/cuda/conv_bias/cutlass_reorder_filter.cuh" | |||
#include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh" | |||
#include "src/cuda/convolution_helper/parameter.cuh" | |||
#include "src/cuda/utils.h" | |||
@@ -110,11 +111,14 @@ void ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::exec( | |||
size_t ho = args.dst_layout->operator[](2), | |||
wo = args.dst_layout->operator[](3); | |||
size_t co; | |||
bool trans_oc; | |||
if (param.format == Format::NCHW32) { | |||
co = args.dst_layout->operator[](1) * 32; | |||
trans_oc = true; | |||
} else { | |||
megdnn_assert(param.format == Format::NCHW32_NCHW4); | |||
co = args.dst_layout->operator[](1) * 4; | |||
trans_oc = false; | |||
} | |||
UNPACK_CONV_PARAMETER(fm, param); | |||
MARK_USED_VAR | |||
@@ -123,23 +127,11 @@ void ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::exec( | |||
int8_t* filter_ptr = nullptr; | |||
if (args.preprocessed_filter == nullptr) { | |||
filter_ptr = reinterpret_cast<int8_t*>(args.workspace.raw_ptr); | |||
// reformat filter from nchw32 to chwn32 | |||
TensorLayout src{{co, ci / 32, fh, fw, 32}, dtype::Int8()}; | |||
src.init_contiguous_stride(); | |||
TensorLayout dst = src; | |||
dst.stride[0] = 32; | |||
dst.stride[1] = co * fh * fw * 32; | |||
dst.stride[2] = co * fw * 32; | |||
dst.stride[3] = co * 32; | |||
dst.stride[4] = 1; | |||
TensorND ts_src, ts_dst; | |||
ts_src.raw_ptr = args.filter_tensor->raw_ptr; | |||
ts_src.layout = src; | |||
ts_dst.raw_ptr = args.workspace.raw_ptr; | |||
ts_dst.layout = dst; | |||
auto&& transpose = | |||
args.opr->handle()->create_operator<RelayoutForward>(); | |||
transpose->exec(ts_src, ts_dst); | |||
// filter: KCRS32 => CRSK32 and reorder oc | |||
cutlass_wrapper::reorder_ncxhwx_imma_filter<8, 32>( | |||
filter_ptr, | |||
reinterpret_cast<int8_t*>(args.filter_tensor->raw_ptr), co, ci, | |||
fh, fw, trans_oc, stream); | |||
} else { | |||
filter_ptr = reinterpret_cast<int8_t*>( | |||
args.preprocessed_filter->tensors[0].raw_ptr); | |||
@@ -182,7 +174,7 @@ void ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::exec( | |||
cutlass_wrapper::GemmCoord{m_algo_param.warp_m, | |||
m_algo_param.warp_n, | |||
m_algo_param.warp_k}, | |||
stream); | |||
m_algo_param.stage, stream); | |||
} else { | |||
megdnn_assert(param.format == Format::NCHW32_NCHW4); | |||
cutlass_wrapper:: | |||
@@ -202,7 +194,7 @@ void ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::exec( | |||
cutlass_wrapper::GemmCoord{m_algo_param.warp_m, | |||
m_algo_param.warp_n, | |||
m_algo_param.warp_k}, | |||
stream); | |||
m_algo_param.stage, stream); | |||
} | |||
} else { | |||
if (param.format == Format::NCHW32) { | |||
@@ -218,7 +210,7 @@ void ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::exec( | |||
cutlass_wrapper::GemmCoord{m_algo_param.warp_m, | |||
m_algo_param.warp_n, | |||
m_algo_param.warp_k}, | |||
stream); | |||
m_algo_param.stage, stream); | |||
} else { | |||
megdnn_assert(param.format == Format::NCHW32_NCHW4); | |||
cutlass_wrapper:: | |||
@@ -238,7 +230,7 @@ void ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::exec( | |||
cutlass_wrapper::GemmCoord{m_algo_param.warp_m, | |||
m_algo_param.warp_n, | |||
m_algo_param.warp_k}, | |||
stream); | |||
m_algo_param.stage, stream); | |||
} | |||
} | |||
after_kernel_launch(); | |||
@@ -246,9 +238,10 @@ void ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::exec( | |||
std::string ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::to_string( | |||
AlgoParam algo_param) { | |||
return ssprintf("%uX%uX%u_%uX%uX%u", algo_param.threadblock_m, | |||
return ssprintf("%uX%uX%u_%uX%uX%u_%u", algo_param.threadblock_m, | |||
algo_param.threadblock_n, algo_param.threadblock_k, | |||
algo_param.warp_m, algo_param.warp_n, algo_param.warp_k); | |||
algo_param.warp_m, algo_param.warp_n, algo_param.warp_k, | |||
algo_param.stage); | |||
} | |||
size_t ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm:: | |||
@@ -267,36 +260,26 @@ void ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::exec_preprocess( | |||
using Format = Param::Format; | |||
auto&& param = args.opr->param(); | |||
auto&& fm = args.filter_meta; | |||
size_t n = args.src_layout->operator[](0), | |||
ci = args.src_layout->operator[](1) * 32, | |||
hi = args.src_layout->operator[](2), | |||
wi = args.src_layout->operator[](3); | |||
size_t ho = args.dst_layout->operator[](2), | |||
wo = args.dst_layout->operator[](3); | |||
size_t ci = args.src_layout->operator[](1) * 32; | |||
size_t co; | |||
bool trans_oc; | |||
if (param.format == Format::NCHW32) { | |||
co = args.dst_layout->operator[](1) * 32; | |||
trans_oc = true; | |||
} else { | |||
megdnn_assert(param.format == Format::NCHW32_NCHW4); | |||
co = args.dst_layout->operator[](1) * 4; | |||
trans_oc = false; | |||
} | |||
UNPACK_CONV_PARAMETER(fm, param); | |||
MARK_USED_VAR | |||
TensorLayout src{{co, ci / 32, fh, fw, 32}, dtype::Int8()}; | |||
src.init_contiguous_stride(); | |||
TensorLayout dst = src; | |||
dst.stride[0] = 32; | |||
dst.stride[1] = co * fh * fw * 32; | |||
dst.stride[2] = co * fw * 32; | |||
dst.stride[3] = co * 32; | |||
dst.stride[4] = 1; | |||
TensorND ts_src, ts_dst; | |||
ts_src.raw_ptr = args.filter_tensor->raw_ptr; | |||
ts_src.layout = src; | |||
ts_dst.raw_ptr = args.preprocessed_filter->tensors[0].raw_ptr; | |||
ts_dst.layout = dst; | |||
auto&& transpose = args.opr->handle()->create_operator<RelayoutForward>(); | |||
transpose->exec(ts_src, ts_dst); | |||
size_t fh = fm.spatial[0], fw = fm.spatial[1]; | |||
cudaStream_t stream = cuda_stream(args.opr->handle()); | |||
// filter: KCRS32 => CRSK32 and reorder oc | |||
cutlass_wrapper::reorder_ncxhwx_imma_filter<8, 32>( | |||
reinterpret_cast<int8_t*>( | |||
args.preprocessed_filter->tensors[0].raw_ptr), | |||
reinterpret_cast<int8_t*>(args.filter_tensor->raw_ptr), co, ci, fh, | |||
fw, trans_oc, stream); | |||
} | |||
#endif | |||
@@ -144,7 +144,8 @@ void ConvBiasForwardImpl::AlgoUInt4Int4NCHW64IMMAImplicitGemm::do_exec( | |||
reinterpret_cast<uint8_t*>(z_ptr), | |||
reinterpret_cast<uint8_t*>(args.dst_tensor->raw_ptr), nullptr, | |||
kern_param, nonlinear_mode, alpha, beta, gamma, delta, theta, | |||
dst_scale, src_zero, threadblock_shape, warp_shape, stream); | |||
dst_scale, src_zero, threadblock_shape, warp_shape, | |||
m_algo_param.stage, stream); | |||
} | |||
void ConvBiasForwardImpl::AlgoUInt4Int4NCHW64IMMAImplicitGemm::update_bias( | |||
@@ -147,7 +147,7 @@ void ConvBiasForwardImpl::AlgoUInt4Int4NHWCIMMAImplicitGemm::do_exec( | |||
reinterpret_cast<uint8_t*>(args.dst_tensor->raw_ptr), nullptr, | |||
kern_param, nonlinear_mode, alpha, beta, gamma, delta, theta, | |||
dst_scale, src_zero, threadblock_shape, warp_shape, | |||
m_algo_param.access_size, stream); | |||
m_algo_param.access_size, m_algo_param.stage, stream); | |||
} else { | |||
cutlass_wrapper::do_conv_bias_uint4_int4_implicit_gemm_imma_nhwc<true>( | |||
reinterpret_cast<uint8_t*>(args.src_tensor->raw_ptr), | |||
@@ -157,7 +157,7 @@ void ConvBiasForwardImpl::AlgoUInt4Int4NHWCIMMAImplicitGemm::do_exec( | |||
reinterpret_cast<uint8_t*>(args.dst_tensor->raw_ptr), nullptr, | |||
kern_param, nonlinear_mode, alpha, beta, gamma, delta, theta, | |||
dst_scale, src_zero, threadblock_shape, warp_shape, | |||
m_algo_param.access_size, stream); | |||
m_algo_param.access_size, m_algo_param.stage, stream); | |||
} | |||
} | |||
@@ -840,21 +840,21 @@ TEST_F(CUDA, CUTLASS_CONV_BIAS_INT8_NCHW32_IMMA) { | |||
param.pad_h = param.pad_w = 1; | |||
param.stride_h = param.stride_w = 1; | |||
param.format = param::ConvBias::Format::NCHW32; | |||
checker.set_param(param).execs({{16, 16, 7, 7, 32}, | |||
{512, 16, 3, 3, 32}, | |||
{1, 16, 1, 1, 32}, | |||
checker.set_param(param).execs({{16, 8, 7, 7, 32}, | |||
{256, 8, 3, 3, 32}, | |||
{1, 8, 1, 1, 32}, | |||
{}, | |||
{}}); | |||
param.nonlineMode = param::ConvBias::NonlineMode::RELU; | |||
checker.set_param(param).execs({{16, 16, 7, 7, 32}, | |||
{512, 16, 1, 1, 32}, | |||
{1, 16, 1, 1, 32}, | |||
checker.set_param(param).execs({{16, 8, 7, 7, 32}, | |||
{256, 8, 1, 1, 32}, | |||
{1, 8, 1, 1, 32}, | |||
{}, | |||
{}}); | |||
param.nonlineMode = param::ConvBias::NonlineMode::H_SWISH; | |||
checker.set_param(param).execs({{16, 16, 7, 7, 32}, | |||
{512, 16, 3, 3, 32}, | |||
{1, 16, 1, 1, 32}, | |||
checker.set_param(param).execs({{16, 8, 7, 7, 32}, | |||
{256, 8, 3, 3, 32}, | |||
{1, 8, 1, 1, 32}, | |||
{}, | |||
{}}); | |||
// use non integer scale | |||
@@ -867,18 +867,18 @@ TEST_F(CUDA, CUTLASS_CONV_BIAS_INT8_NCHW32_IMMA) { | |||
.set_epsilon(1 + 1e-3) | |||
.set_max_avg_error(1e-1) | |||
.set_max_avg_biased_error(1e-1) | |||
.execs({{16, 16, 7, 7, 32}, | |||
{512, 16, 3, 3, 32}, | |||
{1, 16, 1, 1, 32}, | |||
{16, 16, 7, 7, 32}, | |||
.execs({{16, 8, 7, 7, 32}, | |||
{256, 8, 3, 3, 32}, | |||
{1, 8, 1, 1, 32}, | |||
{16, 8, 7, 7, 32}, | |||
{}}); | |||
}; | |||
std::string algo = ConvBias::algo_name<ConvBias::DirectParam>( | |||
"INT8_NCHW32_IMMA_IMPLICIT_GEMM_256X128X64_64X64X64", | |||
"INT8_NCHW32_IMMA_IMPLICIT_GEMM_128X128X64_64X64X64_2", | |||
ConvBias::DirectParam{}); | |||
check(algo); | |||
algo = ConvBias::algo_name<ConvBias::DirectParam>( | |||
"INT8_NCHW32_IMMA_IMPLICIT_GEMM_32X64X64_32X16X64", | |||
"INT8_NCHW32_IMMA_IMPLICIT_GEMM_128X32X32_64X32X32_1", | |||
ConvBias::DirectParam{}); | |||
check(algo); | |||
} | |||
@@ -969,7 +969,7 @@ TEST_F(CUDA, CUTLASS_CONV_BIAS_INT8_NCHW32_NCHW4) { | |||
checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker< | |||
ConvBiasForward>( | |||
ConvBias::algo_name<ConvBias::DirectParam>( | |||
"INT8_NCHW32_IMMA_IMPLICIT_GEMM_256X128X64_64X64X64", | |||
"INT8_NCHW32_IMMA_IMPLICIT_GEMM_128X128X64_64X64X64_2", | |||
ConvBias::DirectParam{}) | |||
.c_str())); | |||
checker.set_dtype(0, dtype::QuantizedS8(1.9980618f)) | |||