GitOrigin-RevId: 025c591f75
tags/v1.6.0-rc1
@@ -5,6 +5,8 @@ genrule(
    outs = cutlass_gen_list,
    cmd = """GEN=$(location //brain/megbrain/dnn/scripts/cutlass_generator:generator.py)
CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations gemm --type simt $(@D)
CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations gemm --type tensorop884 $(@D)
CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations gemm --type tensorop1688 $(@D)
CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations gemv --type simt $(@D)
CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations deconv --type simt $(@D)
CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations conv2d --type simt $(@D)
@@ -252,7 +252,8 @@ def GeneratesGemm(tile, data_type, layout_a, layout_b, layout_c, min_cc, align_a
    if tile.math_instruction.element_accumulator == DataType.s32:
        epilogues = [EpilogueFunctor.LinearCombinationClamp]
    else:
        assert tile.math_instruction.element_accumulator == DataType.f32 or \
            tile.math_instruction.element_accumulator == DataType.f16
        epilogues = [EpilogueFunctor.LinearCombination]
    for epilogue in epilogues:
@@ -799,7 +800,22 @@ class EmitGemmSplitKParallelInstance:
    ${epilogue_vector_length},
    ${element_accumulator},
    ${element_epilogue}
  >,
  cutlass::epilogue::thread::Convert<
    ${element_accumulator},
    ${epilogue_vector_length},
    ${element_accumulator}
  >,
  cutlass::reduction::thread::ReduceAdd<
    ${element_accumulator},
    ${element_accumulator},
    ${epilogue_vector_length}
  >,
  cutlass::gemm::threadblock::GemmSplitKHorizontalThreadblockSwizzle,
  ${stages},
  ${align_a},
  ${align_b},
  ${math_operation}
>;
"""

    def emit(self, operation):
@@ -831,7 +847,10 @@ class EmitGemmSplitKParallelInstance:
            'epilogue_vector_length': str(epilogue_vector_length),
            'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
            'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
            'stages': str(operation.tile_description.stages),
            'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation],
            'align_a': str(operation.A.alignment),
            'align_b': str(operation.B.alignment),
        }

        return SubstituteTemplate(self.template, values)
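For reference, SubstituteTemplate simply replaces ${name} placeholders in the template with the mapped strings. A minimal sketch of the mechanism (the real helper lives elsewhere in the generator; this standalone version is illustrative only):

import re

def substitute_template(template, values):
    # Replace each ${key} placeholder with its value; unknown keys are left
    # untouched so missing entries are easy to spot in the emitted source.
    def lookup(match):
        return values.get(match.group(1), match.group(0))
    return re.sub(r"\$\{(\w+)\}", lookup, template)

print(substitute_template(
    "cutlass::gemm::device::GemmSplitKParallel<${element_a}>",
    {"element_a": "cutlass::half_t"}))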
@@ -32,6 +32,8 @@ if __name__ == "__main__":
        f.write("# Generated by dnn/scripts/cutlass_generator/gen_list.py\n\n")
        f.write("cutlass_gen_list = [\n")
        write_op_list(f, "gemm", "simt")
        write_op_list(f, "gemm", "tensorop1688")
        write_op_list(f, "gemm", "tensorop884")
        write_op_list(f, "gemv", "simt")
        write_op_list(f, "deconv", "simt")
        write_op_list(f, "conv2d", "simt")
@@ -597,6 +597,131 @@ def GenerateGemv_Simt(args):
    return operations

#
def GeneratesGemm_TensorOp_1688(args):
    layouts = [
        (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.RowMajor),  # nn
        (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.RowMajor),     # nt
        (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor),     # tn
        (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.RowMajor),        # tt
    ]
    math_instructions = [
        MathInstruction( \
            [16, 8, 8], \
            DataType.f16, DataType.f16, DataType.f32, \
            OpcodeClass.TensorOp, \
            MathOperation.multiply_add),
        MathInstruction( \
            [16, 8, 8], \
            DataType.f16, DataType.f16, DataType.f16, \
            OpcodeClass.TensorOp, \
            MathOperation.multiply_add),
    ]

    min_cc = 75
    max_cc = 1024

    alignment_constraints = [8, 4, 2,
                             # 1
                             ]

    operations = []
    for math_inst in math_instructions:
        for layout in layouts:
            for align in alignment_constraints:
                tile_descriptions = [
                    TileDescription([256, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc),
                    TileDescription([128, 256, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc),
                    TileDescription([128, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
                    ## some tile configurations are commented out to reduce compilation time and binary size
                    # TileDescription([ 64, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
                    # TileDescription([128,  64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
                    # TileDescription([ 64,  64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
                ]
                data_type = [
                    math_inst.element_a,
                    math_inst.element_b,
                    math_inst.element_a,
                    math_inst.element_accumulator,
                ]
                for tile in tile_descriptions:
                    operations += GeneratesGemm(tile, \
                                                data_type, \
                                                layout[0], \
                                                layout[1], \
                                                layout[2], \
                                                min_cc, \
                                                align * 16, \
                                                align * 16, \
                                                align * 16)
    return operations

#
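Note the unit change at the call into GeneratesGemm: alignment_constraints is expressed in elements, while the trailing arguments appear to be alignments in bits, hence align * 16 for 16-bit f16 operands (an assumption inferred from the element width, not stated in the diff). A quick sanity check of that reading:

f16_bits = 16
for align_elems in (8, 4, 2):
    align_bits = align_elems * f16_bits  # value passed to GeneratesGemm
    # 8 f16 elements = 128 bits = one 16-byte vectorized access
    print(align_elems, "elements ->", align_bits, "bits =",
          align_bits // 8, "bytes")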
def GeneratesGemm_TensorOp_884(args):
    layouts = [
        (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.RowMajor),  # nn
        (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.RowMajor),     # nt
        (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor),     # tn
        (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.RowMajor),        # tt
    ]
    math_instructions = [
        MathInstruction( \
            [8, 8, 4], \
            DataType.f16, DataType.f16, DataType.f32, \
            OpcodeClass.TensorOp, \
            MathOperation.multiply_add),
        MathInstruction( \
            [8, 8, 4], \
            DataType.f16, DataType.f16, DataType.f16, \
            OpcodeClass.TensorOp, \
            MathOperation.multiply_add),
    ]

    min_cc = 70
    max_cc = 75

    alignment_constraints = [8, 4, 2,
                             # 1
                             ]

    operations = []
    for math_inst in math_instructions:
        for layout in layouts:
            for align in alignment_constraints:
                tile_descriptions = [
                    TileDescription([256, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc),
                    TileDescription([128, 256, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc),
                    TileDescription([128, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
                    ## some tile configurations are commented out to reduce compilation time and binary size
                    # TileDescription([ 64, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
                    # TileDescription([128,  64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
                    # TileDescription([ 64,  64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
                ]
                data_type = [
                    math_inst.element_a,
                    math_inst.element_b,
                    math_inst.element_a,
                    math_inst.element_accumulator,
                ]
                for tile in tile_descriptions:
                    operations += GeneratesGemm(tile, \
                                                data_type, \
                                                layout[0], \
                                                layout[1], \
                                                layout[2], \
                                                min_cc, \
                                                align * 16, \
                                                align * 16, \
                                                align * 16)
    return operations

#
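As a cross-check against the generated file list further down, each of these two functions enumerates 2 math instructions x 4 layouts x 3 alignments x 3 tiles = 72 GEMM configurations, and every configuration is emitted both as a plain kernel and as a split-k parallel kernel:

math_insts, layouts, aligns, tiles = 2, 4, 3, 3
configs = math_insts * layouts * aligns * tiles  # 72
files_per_type = configs * 2                     # plain + split_k_parallel
# 144 .cu files per kernel type, plus one all_gemm_<type>_operations.cu
print(files_per_type)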
def GenerateConv2dOperations(args):
    if args.type == "simt":
        return GenerateConv2d_Simt(args)
@@ -613,9 +738,14 @@ def GenerateDeconvOperations(args):
        return GenerateDeconv_Simt(args)

def GenerateGemmOperations(args):
    if args.type == "tensorop884":
        return GeneratesGemm_TensorOp_884(args)
    elif args.type == "tensorop1688":
        return GeneratesGemm_TensorOp_1688(args)
    else:
        assert args.type == "simt", "operation gemm only supports " \
            "simt. (got:{})".format(args.type)
        return GenerateGemm_Simt(args)

def GenerateGemvOperations(args):
    assert args.type == "simt", "operation gemv only supports " \
@@ -631,7 +761,7 @@ if __name__ == "__main__":
    parser.add_argument("--operations", type=str, choices=['gemm', 'gemv', 'conv2d', 'deconv'],
                        required=True, help="Specifies the operation to generate (gemm, gemv, conv2d, deconv)")
    parser.add_argument("output", type=str, help="output directory for CUTLASS kernel files")
    parser.add_argument("--type", type=str, choices=['simt', 'tensorop8816', 'tensorop8832', 'tensorop884', 'tensorop1688'],
                        default='simt', help="kernel type of CUTLASS kernel generator")

gemv_wrapper_path = "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper_batched_gemv_strided.cuinl"
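Putting the pieces together, the new kernel families are generated through the same CLI dispatch as before. A usage sketch, run in the generator module's context (Namespace stands in for parsed argv, mirroring the genrule lines at the top):

from argparse import Namespace

# Dispatches to GeneratesGemm_TensorOp_1688 via GenerateGemmOperations.
args = Namespace(operations="gemm", type="tensorop1688", output="out")
ops = GenerateGemmOperations(args)
print(len(ops))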
@@ -138,6 +138,296 @@ cutlass_gen_list = [
    "cutlass_simt_sgemm_256x64_8x2_tt_align1.cu",
    "cutlass_simt_sgemm_split_k_parallel_256x64_8x2_tt_align1.cu",
    "all_gemm_simt_operations.cu",
    "cutlass_tensorop_f16_s1688gemm_f16_256x128_32x2_nn_align8.cu",
    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_256x128_32x2_nn_align8.cu",
    "cutlass_tensorop_f16_s1688gemm_f16_128x256_32x2_nn_align8.cu",
    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x256_32x2_nn_align8.cu",
    "cutlass_tensorop_f16_s1688gemm_f16_128x128_32x2_nn_align8.cu",
    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x128_32x2_nn_align8.cu",
    "cutlass_tensorop_f16_s1688gemm_f16_256x128_32x2_nn_align4.cu",
    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_256x128_32x2_nn_align4.cu",
    "cutlass_tensorop_f16_s1688gemm_f16_128x256_32x2_nn_align4.cu",
    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x256_32x2_nn_align4.cu",
    "cutlass_tensorop_f16_s1688gemm_f16_128x128_32x2_nn_align4.cu",
    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x128_32x2_nn_align4.cu",
    "cutlass_tensorop_f16_s1688gemm_f16_256x128_32x2_nn_align2.cu",
    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_256x128_32x2_nn_align2.cu",
    "cutlass_tensorop_f16_s1688gemm_f16_128x256_32x2_nn_align2.cu",
    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x256_32x2_nn_align2.cu",
    "cutlass_tensorop_f16_s1688gemm_f16_128x128_32x2_nn_align2.cu",
    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x128_32x2_nn_align2.cu",
    "cutlass_tensorop_f16_s1688gemm_f16_256x128_32x2_nt_align8.cu",
    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_256x128_32x2_nt_align8.cu",
    "cutlass_tensorop_f16_s1688gemm_f16_128x256_32x2_nt_align8.cu",
    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x256_32x2_nt_align8.cu",
    "cutlass_tensorop_f16_s1688gemm_f16_128x128_32x2_nt_align8.cu",
    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x128_32x2_nt_align8.cu",
    "cutlass_tensorop_f16_s1688gemm_f16_256x128_32x2_nt_align4.cu",
    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_256x128_32x2_nt_align4.cu",
    "cutlass_tensorop_f16_s1688gemm_f16_128x256_32x2_nt_align4.cu",
    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x256_32x2_nt_align4.cu",
    "cutlass_tensorop_f16_s1688gemm_f16_128x128_32x2_nt_align4.cu",
    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x128_32x2_nt_align4.cu",
    "cutlass_tensorop_f16_s1688gemm_f16_256x128_32x2_nt_align2.cu",
    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_256x128_32x2_nt_align2.cu",
    "cutlass_tensorop_f16_s1688gemm_f16_128x256_32x2_nt_align2.cu",
    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x256_32x2_nt_align2.cu",
    "cutlass_tensorop_f16_s1688gemm_f16_128x128_32x2_nt_align2.cu",
    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x128_32x2_nt_align2.cu",
    "cutlass_tensorop_f16_s1688gemm_f16_256x128_32x2_tn_align8.cu",
    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_256x128_32x2_tn_align8.cu",
    "cutlass_tensorop_f16_s1688gemm_f16_128x256_32x2_tn_align8.cu",
    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x256_32x2_tn_align8.cu",
    "cutlass_tensorop_f16_s1688gemm_f16_128x128_32x2_tn_align8.cu",
    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x128_32x2_tn_align8.cu",
    "cutlass_tensorop_f16_s1688gemm_f16_256x128_32x2_tn_align4.cu",
    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_256x128_32x2_tn_align4.cu",
    "cutlass_tensorop_f16_s1688gemm_f16_128x256_32x2_tn_align4.cu",
    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x256_32x2_tn_align4.cu",
    "cutlass_tensorop_f16_s1688gemm_f16_128x128_32x2_tn_align4.cu",
    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x128_32x2_tn_align4.cu",
    "cutlass_tensorop_f16_s1688gemm_f16_256x128_32x2_tn_align2.cu",
    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_256x128_32x2_tn_align2.cu",
    "cutlass_tensorop_f16_s1688gemm_f16_128x256_32x2_tn_align2.cu",
    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x256_32x2_tn_align2.cu",
    "cutlass_tensorop_f16_s1688gemm_f16_128x128_32x2_tn_align2.cu",
    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x128_32x2_tn_align2.cu",
    "cutlass_tensorop_f16_s1688gemm_f16_256x128_32x2_tt_align8.cu",
    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_256x128_32x2_tt_align8.cu",
    "cutlass_tensorop_f16_s1688gemm_f16_128x256_32x2_tt_align8.cu",
    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x256_32x2_tt_align8.cu",
    "cutlass_tensorop_f16_s1688gemm_f16_128x128_32x2_tt_align8.cu",
    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x128_32x2_tt_align8.cu",
    "cutlass_tensorop_f16_s1688gemm_f16_256x128_32x2_tt_align4.cu",
    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_256x128_32x2_tt_align4.cu",
    "cutlass_tensorop_f16_s1688gemm_f16_128x256_32x2_tt_align4.cu",
    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x256_32x2_tt_align4.cu",
    "cutlass_tensorop_f16_s1688gemm_f16_128x128_32x2_tt_align4.cu",
    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x128_32x2_tt_align4.cu",
    "cutlass_tensorop_f16_s1688gemm_f16_256x128_32x2_tt_align2.cu",
    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_256x128_32x2_tt_align2.cu",
    "cutlass_tensorop_f16_s1688gemm_f16_128x256_32x2_tt_align2.cu",
    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x256_32x2_tt_align2.cu",
    "cutlass_tensorop_f16_s1688gemm_f16_128x128_32x2_tt_align2.cu",
    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x128_32x2_tt_align2.cu",
    "cutlass_tensorop_h1688gemm_256x128_32x2_nn_align8.cu",
    "cutlass_tensorop_h1688gemm_split_k_parallel_256x128_32x2_nn_align8.cu",
    "cutlass_tensorop_h1688gemm_128x256_32x2_nn_align8.cu",
    "cutlass_tensorop_h1688gemm_split_k_parallel_128x256_32x2_nn_align8.cu",
    "cutlass_tensorop_h1688gemm_128x128_32x2_nn_align8.cu",
    "cutlass_tensorop_h1688gemm_split_k_parallel_128x128_32x2_nn_align8.cu",
    "cutlass_tensorop_h1688gemm_256x128_32x2_nn_align4.cu",
    "cutlass_tensorop_h1688gemm_split_k_parallel_256x128_32x2_nn_align4.cu",
    "cutlass_tensorop_h1688gemm_128x256_32x2_nn_align4.cu",
    "cutlass_tensorop_h1688gemm_split_k_parallel_128x256_32x2_nn_align4.cu",
    "cutlass_tensorop_h1688gemm_128x128_32x2_nn_align4.cu",
    "cutlass_tensorop_h1688gemm_split_k_parallel_128x128_32x2_nn_align4.cu",
    "cutlass_tensorop_h1688gemm_256x128_32x2_nn_align2.cu",
    "cutlass_tensorop_h1688gemm_split_k_parallel_256x128_32x2_nn_align2.cu",
    "cutlass_tensorop_h1688gemm_128x256_32x2_nn_align2.cu",
    "cutlass_tensorop_h1688gemm_split_k_parallel_128x256_32x2_nn_align2.cu",
    "cutlass_tensorop_h1688gemm_128x128_32x2_nn_align2.cu",
    "cutlass_tensorop_h1688gemm_split_k_parallel_128x128_32x2_nn_align2.cu",
    "cutlass_tensorop_h1688gemm_256x128_32x2_nt_align8.cu",
    "cutlass_tensorop_h1688gemm_split_k_parallel_256x128_32x2_nt_align8.cu",
    "cutlass_tensorop_h1688gemm_128x256_32x2_nt_align8.cu",
    "cutlass_tensorop_h1688gemm_split_k_parallel_128x256_32x2_nt_align8.cu",
    "cutlass_tensorop_h1688gemm_128x128_32x2_nt_align8.cu",
    "cutlass_tensorop_h1688gemm_split_k_parallel_128x128_32x2_nt_align8.cu",
    "cutlass_tensorop_h1688gemm_256x128_32x2_nt_align4.cu",
    "cutlass_tensorop_h1688gemm_split_k_parallel_256x128_32x2_nt_align4.cu",
    "cutlass_tensorop_h1688gemm_128x256_32x2_nt_align4.cu",
    "cutlass_tensorop_h1688gemm_split_k_parallel_128x256_32x2_nt_align4.cu",
    "cutlass_tensorop_h1688gemm_128x128_32x2_nt_align4.cu",
    "cutlass_tensorop_h1688gemm_split_k_parallel_128x128_32x2_nt_align4.cu",
    "cutlass_tensorop_h1688gemm_256x128_32x2_nt_align2.cu",
    "cutlass_tensorop_h1688gemm_split_k_parallel_256x128_32x2_nt_align2.cu",
    "cutlass_tensorop_h1688gemm_128x256_32x2_nt_align2.cu",
    "cutlass_tensorop_h1688gemm_split_k_parallel_128x256_32x2_nt_align2.cu",
    "cutlass_tensorop_h1688gemm_128x128_32x2_nt_align2.cu",
    "cutlass_tensorop_h1688gemm_split_k_parallel_128x128_32x2_nt_align2.cu",
    "cutlass_tensorop_h1688gemm_256x128_32x2_tn_align8.cu",
    "cutlass_tensorop_h1688gemm_split_k_parallel_256x128_32x2_tn_align8.cu",
    "cutlass_tensorop_h1688gemm_128x256_32x2_tn_align8.cu",
    "cutlass_tensorop_h1688gemm_split_k_parallel_128x256_32x2_tn_align8.cu",
    "cutlass_tensorop_h1688gemm_128x128_32x2_tn_align8.cu",
    "cutlass_tensorop_h1688gemm_split_k_parallel_128x128_32x2_tn_align8.cu",
    "cutlass_tensorop_h1688gemm_256x128_32x2_tn_align4.cu",
    "cutlass_tensorop_h1688gemm_split_k_parallel_256x128_32x2_tn_align4.cu",
    "cutlass_tensorop_h1688gemm_128x256_32x2_tn_align4.cu",
    "cutlass_tensorop_h1688gemm_split_k_parallel_128x256_32x2_tn_align4.cu",
    "cutlass_tensorop_h1688gemm_128x128_32x2_tn_align4.cu",
    "cutlass_tensorop_h1688gemm_split_k_parallel_128x128_32x2_tn_align4.cu",
    "cutlass_tensorop_h1688gemm_256x128_32x2_tn_align2.cu",
    "cutlass_tensorop_h1688gemm_split_k_parallel_256x128_32x2_tn_align2.cu",
    "cutlass_tensorop_h1688gemm_128x256_32x2_tn_align2.cu",
    "cutlass_tensorop_h1688gemm_split_k_parallel_128x256_32x2_tn_align2.cu",
    "cutlass_tensorop_h1688gemm_128x128_32x2_tn_align2.cu",
    "cutlass_tensorop_h1688gemm_split_k_parallel_128x128_32x2_tn_align2.cu",
    "cutlass_tensorop_h1688gemm_256x128_32x2_tt_align8.cu",
    "cutlass_tensorop_h1688gemm_split_k_parallel_256x128_32x2_tt_align8.cu",
    "cutlass_tensorop_h1688gemm_128x256_32x2_tt_align8.cu",
    "cutlass_tensorop_h1688gemm_split_k_parallel_128x256_32x2_tt_align8.cu",
    "cutlass_tensorop_h1688gemm_128x128_32x2_tt_align8.cu",
    "cutlass_tensorop_h1688gemm_split_k_parallel_128x128_32x2_tt_align8.cu",
    "cutlass_tensorop_h1688gemm_256x128_32x2_tt_align4.cu",
    "cutlass_tensorop_h1688gemm_split_k_parallel_256x128_32x2_tt_align4.cu",
    "cutlass_tensorop_h1688gemm_128x256_32x2_tt_align4.cu",
    "cutlass_tensorop_h1688gemm_split_k_parallel_128x256_32x2_tt_align4.cu",
    "cutlass_tensorop_h1688gemm_128x128_32x2_tt_align4.cu",
    "cutlass_tensorop_h1688gemm_split_k_parallel_128x128_32x2_tt_align4.cu",
    "cutlass_tensorop_h1688gemm_256x128_32x2_tt_align2.cu",
    "cutlass_tensorop_h1688gemm_split_k_parallel_256x128_32x2_tt_align2.cu",
    "cutlass_tensorop_h1688gemm_128x256_32x2_tt_align2.cu",
    "cutlass_tensorop_h1688gemm_split_k_parallel_128x256_32x2_tt_align2.cu",
    "cutlass_tensorop_h1688gemm_128x128_32x2_tt_align2.cu",
    "cutlass_tensorop_h1688gemm_split_k_parallel_128x128_32x2_tt_align2.cu",
    "all_gemm_tensorop1688_operations.cu",
    "cutlass_tensorop_f16_s884gemm_f16_256x128_32x2_nn_align8.cu",
    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_256x128_32x2_nn_align8.cu",
    "cutlass_tensorop_f16_s884gemm_f16_128x256_32x2_nn_align8.cu",
    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x256_32x2_nn_align8.cu",
    "cutlass_tensorop_f16_s884gemm_f16_128x128_32x2_nn_align8.cu",
    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x128_32x2_nn_align8.cu",
    "cutlass_tensorop_f16_s884gemm_f16_256x128_32x2_nn_align4.cu",
    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_256x128_32x2_nn_align4.cu",
    "cutlass_tensorop_f16_s884gemm_f16_128x256_32x2_nn_align4.cu",
    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x256_32x2_nn_align4.cu",
    "cutlass_tensorop_f16_s884gemm_f16_128x128_32x2_nn_align4.cu",
    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x128_32x2_nn_align4.cu",
    "cutlass_tensorop_f16_s884gemm_f16_256x128_32x2_nn_align2.cu",
    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_256x128_32x2_nn_align2.cu",
    "cutlass_tensorop_f16_s884gemm_f16_128x256_32x2_nn_align2.cu",
    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x256_32x2_nn_align2.cu",
    "cutlass_tensorop_f16_s884gemm_f16_128x128_32x2_nn_align2.cu",
    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x128_32x2_nn_align2.cu",
    "cutlass_tensorop_f16_s884gemm_f16_256x128_32x2_nt_align8.cu",
    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_256x128_32x2_nt_align8.cu",
    "cutlass_tensorop_f16_s884gemm_f16_128x256_32x2_nt_align8.cu",
    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x256_32x2_nt_align8.cu",
    "cutlass_tensorop_f16_s884gemm_f16_128x128_32x2_nt_align8.cu",
    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x128_32x2_nt_align8.cu",
    "cutlass_tensorop_f16_s884gemm_f16_256x128_32x2_nt_align4.cu",
    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_256x128_32x2_nt_align4.cu",
    "cutlass_tensorop_f16_s884gemm_f16_128x256_32x2_nt_align4.cu",
    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x256_32x2_nt_align4.cu",
    "cutlass_tensorop_f16_s884gemm_f16_128x128_32x2_nt_align4.cu",
    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x128_32x2_nt_align4.cu",
    "cutlass_tensorop_f16_s884gemm_f16_256x128_32x2_nt_align2.cu",
    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_256x128_32x2_nt_align2.cu",
    "cutlass_tensorop_f16_s884gemm_f16_128x256_32x2_nt_align2.cu",
    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x256_32x2_nt_align2.cu",
    "cutlass_tensorop_f16_s884gemm_f16_128x128_32x2_nt_align2.cu",
    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x128_32x2_nt_align2.cu",
    "cutlass_tensorop_f16_s884gemm_f16_256x128_32x2_tn_align8.cu",
    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_256x128_32x2_tn_align8.cu",
    "cutlass_tensorop_f16_s884gemm_f16_128x256_32x2_tn_align8.cu",
    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x256_32x2_tn_align8.cu",
    "cutlass_tensorop_f16_s884gemm_f16_128x128_32x2_tn_align8.cu",
    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x128_32x2_tn_align8.cu",
    "cutlass_tensorop_f16_s884gemm_f16_256x128_32x2_tn_align4.cu",
    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_256x128_32x2_tn_align4.cu",
    "cutlass_tensorop_f16_s884gemm_f16_128x256_32x2_tn_align4.cu",
    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x256_32x2_tn_align4.cu",
    "cutlass_tensorop_f16_s884gemm_f16_128x128_32x2_tn_align4.cu",
    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x128_32x2_tn_align4.cu",
    "cutlass_tensorop_f16_s884gemm_f16_256x128_32x2_tn_align2.cu",
    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_256x128_32x2_tn_align2.cu",
    "cutlass_tensorop_f16_s884gemm_f16_128x256_32x2_tn_align2.cu",
    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x256_32x2_tn_align2.cu",
    "cutlass_tensorop_f16_s884gemm_f16_128x128_32x2_tn_align2.cu",
    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x128_32x2_tn_align2.cu",
    "cutlass_tensorop_f16_s884gemm_f16_256x128_32x2_tt_align8.cu",
    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_256x128_32x2_tt_align8.cu",
    "cutlass_tensorop_f16_s884gemm_f16_128x256_32x2_tt_align8.cu",
    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x256_32x2_tt_align8.cu",
    "cutlass_tensorop_f16_s884gemm_f16_128x128_32x2_tt_align8.cu",
    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x128_32x2_tt_align8.cu",
    "cutlass_tensorop_f16_s884gemm_f16_256x128_32x2_tt_align4.cu",
    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_256x128_32x2_tt_align4.cu",
    "cutlass_tensorop_f16_s884gemm_f16_128x256_32x2_tt_align4.cu",
    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x256_32x2_tt_align4.cu",
    "cutlass_tensorop_f16_s884gemm_f16_128x128_32x2_tt_align4.cu",
    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x128_32x2_tt_align4.cu",
    "cutlass_tensorop_f16_s884gemm_f16_256x128_32x2_tt_align2.cu",
    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_256x128_32x2_tt_align2.cu",
    "cutlass_tensorop_f16_s884gemm_f16_128x256_32x2_tt_align2.cu",
    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x256_32x2_tt_align2.cu",
    "cutlass_tensorop_f16_s884gemm_f16_128x128_32x2_tt_align2.cu",
    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x128_32x2_tt_align2.cu",
    "cutlass_tensorop_h884gemm_256x128_32x2_nn_align8.cu",
    "cutlass_tensorop_h884gemm_split_k_parallel_256x128_32x2_nn_align8.cu",
    "cutlass_tensorop_h884gemm_128x256_32x2_nn_align8.cu",
    "cutlass_tensorop_h884gemm_split_k_parallel_128x256_32x2_nn_align8.cu",
    "cutlass_tensorop_h884gemm_128x128_32x2_nn_align8.cu",
    "cutlass_tensorop_h884gemm_split_k_parallel_128x128_32x2_nn_align8.cu",
    "cutlass_tensorop_h884gemm_256x128_32x2_nn_align4.cu",
    "cutlass_tensorop_h884gemm_split_k_parallel_256x128_32x2_nn_align4.cu",
    "cutlass_tensorop_h884gemm_128x256_32x2_nn_align4.cu",
    "cutlass_tensorop_h884gemm_split_k_parallel_128x256_32x2_nn_align4.cu",
    "cutlass_tensorop_h884gemm_128x128_32x2_nn_align4.cu",
    "cutlass_tensorop_h884gemm_split_k_parallel_128x128_32x2_nn_align4.cu",
    "cutlass_tensorop_h884gemm_256x128_32x2_nn_align2.cu",
    "cutlass_tensorop_h884gemm_split_k_parallel_256x128_32x2_nn_align2.cu",
    "cutlass_tensorop_h884gemm_128x256_32x2_nn_align2.cu",
    "cutlass_tensorop_h884gemm_split_k_parallel_128x256_32x2_nn_align2.cu",
    "cutlass_tensorop_h884gemm_128x128_32x2_nn_align2.cu",
    "cutlass_tensorop_h884gemm_split_k_parallel_128x128_32x2_nn_align2.cu",
    "cutlass_tensorop_h884gemm_256x128_32x2_nt_align8.cu",
    "cutlass_tensorop_h884gemm_split_k_parallel_256x128_32x2_nt_align8.cu",
    "cutlass_tensorop_h884gemm_128x256_32x2_nt_align8.cu",
    "cutlass_tensorop_h884gemm_split_k_parallel_128x256_32x2_nt_align8.cu",
    "cutlass_tensorop_h884gemm_128x128_32x2_nt_align8.cu",
    "cutlass_tensorop_h884gemm_split_k_parallel_128x128_32x2_nt_align8.cu",
    "cutlass_tensorop_h884gemm_256x128_32x2_nt_align4.cu",
    "cutlass_tensorop_h884gemm_split_k_parallel_256x128_32x2_nt_align4.cu",
    "cutlass_tensorop_h884gemm_128x256_32x2_nt_align4.cu",
    "cutlass_tensorop_h884gemm_split_k_parallel_128x256_32x2_nt_align4.cu",
    "cutlass_tensorop_h884gemm_128x128_32x2_nt_align4.cu",
    "cutlass_tensorop_h884gemm_split_k_parallel_128x128_32x2_nt_align4.cu",
    "cutlass_tensorop_h884gemm_256x128_32x2_nt_align2.cu",
    "cutlass_tensorop_h884gemm_split_k_parallel_256x128_32x2_nt_align2.cu",
    "cutlass_tensorop_h884gemm_128x256_32x2_nt_align2.cu",
    "cutlass_tensorop_h884gemm_split_k_parallel_128x256_32x2_nt_align2.cu",
    "cutlass_tensorop_h884gemm_128x128_32x2_nt_align2.cu",
    "cutlass_tensorop_h884gemm_split_k_parallel_128x128_32x2_nt_align2.cu",
    "cutlass_tensorop_h884gemm_256x128_32x2_tn_align8.cu",
    "cutlass_tensorop_h884gemm_split_k_parallel_256x128_32x2_tn_align8.cu",
    "cutlass_tensorop_h884gemm_128x256_32x2_tn_align8.cu",
    "cutlass_tensorop_h884gemm_split_k_parallel_128x256_32x2_tn_align8.cu",
    "cutlass_tensorop_h884gemm_128x128_32x2_tn_align8.cu",
    "cutlass_tensorop_h884gemm_split_k_parallel_128x128_32x2_tn_align8.cu",
    "cutlass_tensorop_h884gemm_256x128_32x2_tn_align4.cu",
    "cutlass_tensorop_h884gemm_split_k_parallel_256x128_32x2_tn_align4.cu",
    "cutlass_tensorop_h884gemm_128x256_32x2_tn_align4.cu",
    "cutlass_tensorop_h884gemm_split_k_parallel_128x256_32x2_tn_align4.cu",
    "cutlass_tensorop_h884gemm_128x128_32x2_tn_align4.cu",
    "cutlass_tensorop_h884gemm_split_k_parallel_128x128_32x2_tn_align4.cu",
    "cutlass_tensorop_h884gemm_256x128_32x2_tn_align2.cu",
    "cutlass_tensorop_h884gemm_split_k_parallel_256x128_32x2_tn_align2.cu",
    "cutlass_tensorop_h884gemm_128x256_32x2_tn_align2.cu",
    "cutlass_tensorop_h884gemm_split_k_parallel_128x256_32x2_tn_align2.cu",
    "cutlass_tensorop_h884gemm_128x128_32x2_tn_align2.cu",
    "cutlass_tensorop_h884gemm_split_k_parallel_128x128_32x2_tn_align2.cu",
    "cutlass_tensorop_h884gemm_256x128_32x2_tt_align8.cu",
    "cutlass_tensorop_h884gemm_split_k_parallel_256x128_32x2_tt_align8.cu",
    "cutlass_tensorop_h884gemm_128x256_32x2_tt_align8.cu",
    "cutlass_tensorop_h884gemm_split_k_parallel_128x256_32x2_tt_align8.cu",
    "cutlass_tensorop_h884gemm_128x128_32x2_tt_align8.cu",
    "cutlass_tensorop_h884gemm_split_k_parallel_128x128_32x2_tt_align8.cu",
    "cutlass_tensorop_h884gemm_256x128_32x2_tt_align4.cu",
    "cutlass_tensorop_h884gemm_split_k_parallel_256x128_32x2_tt_align4.cu",
    "cutlass_tensorop_h884gemm_128x256_32x2_tt_align4.cu",
    "cutlass_tensorop_h884gemm_split_k_parallel_128x256_32x2_tt_align4.cu",
    "cutlass_tensorop_h884gemm_128x128_32x2_tt_align4.cu",
    "cutlass_tensorop_h884gemm_split_k_parallel_128x128_32x2_tt_align4.cu",
    "cutlass_tensorop_h884gemm_256x128_32x2_tt_align2.cu",
    "cutlass_tensorop_h884gemm_split_k_parallel_256x128_32x2_tt_align2.cu",
    "cutlass_tensorop_h884gemm_128x256_32x2_tt_align2.cu",
    "cutlass_tensorop_h884gemm_split_k_parallel_128x256_32x2_tt_align2.cu",
    "cutlass_tensorop_h884gemm_128x128_32x2_tt_align2.cu",
    "cutlass_tensorop_h884gemm_split_k_parallel_128x128_32x2_tt_align2.cu",
    "all_gemm_tensorop884_operations.cu",
    "cutlass_simt_sgemv_batched_strided_1x128_32_tt_align4x4.cu",
    "cutlass_simt_sgemv_batched_strided_1x128_16_tt_align4x2.cu",
    "cutlass_simt_sgemv_batched_strided_1x128_8_tt_align4x1.cu",
@@ -646,4 +936,4 @@ cutlass_gen_list = [
    "cutlass_tensorop_u4_i8832fprop_roc_relu_u4_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu",
    "cutlass_tensorop_u4_i8832fprop_1x1_roc_relu_u4_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu",
    "all_conv2d_tensorop8832_operations.cu",
]
@@ -151,6 +151,8 @@ if(MGE_WITH_CUDA)
    set(${gen_files} "${${gen_files}}" PARENT_SCOPE)
  endfunction()

  gen_cutlass_kimpl(gemm simt CUTLASS_SOURCES)
  gen_cutlass_kimpl(gemm tensorop884 CUTLASS_SOURCES)
  gen_cutlass_kimpl(gemm tensorop1688 CUTLASS_SOURCES)
  gen_cutlass_kimpl(gemv simt CUTLASS_SOURCES)
  gen_cutlass_kimpl(deconv simt CUTLASS_SOURCES)
  gen_cutlass_kimpl(conv2d simt CUTLASS_SOURCES)
@@ -49,6 +49,8 @@ namespace library {
        (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)

void initialize_all_gemm_simt_operations(Manifest& manifest);
void initialize_all_gemm_tensorop884_operations(Manifest& manifest);
void initialize_all_gemm_tensorop1688_operations(Manifest& manifest);
void initialize_all_conv2d_simt_operations(Manifest& manifest);
void initialize_all_conv2d_tensorop8816_operations(Manifest& manifest);
void initialize_all_conv2d_tensorop8832_operations(Manifest& manifest);
@@ -56,6 +58,8 @@ void initialize_all_deconv_simt_operations(Manifest& manifest);

void initialize_all(Manifest& manifest) {
    initialize_all_gemm_simt_operations(manifest);
    initialize_all_gemm_tensorop884_operations(manifest);
    initialize_all_gemm_tensorop1688_operations(manifest);
    initialize_all_conv2d_simt_operations(manifest);
    initialize_all_conv2d_tensorop8816_operations(manifest);
    initialize_all_conv2d_tensorop8832_operations(manifest);
@@ -55,6 +55,8 @@ GemmKey get_gemm_key_from_desc(const GemmDescription& desc) {
    key.layout_B = desc.B.layout;
    key.element_C = desc.C.element;
    key.layout_C = desc.C.layout;
    key.element_accumulator =
            desc.tile_description.math_instruction.element_accumulator;
    key.threadblock_shape_m = desc.tile_description.threadblock_shape.m();
    key.threadblock_shape_n = desc.tile_description.threadblock_shape.n();
@@ -75,6 +77,8 @@ GemmKey get_gemm_key_from_desc(const GemmDescription& desc) {
            desc.tile_description.math_instruction.instruction_shape.k();
    key.stages = desc.stages;
    key.alignment_A = desc.A.alignment;
    key.alignment_B = desc.B.alignment;
    key.split_k_mode = desc.split_k_mode;

    return key;
@@ -77,6 +77,7 @@ struct GemmKey {
    LayoutTypeID layout_B;
    NumericTypeID element_C;
    LayoutTypeID layout_C;
    NumericTypeID element_accumulator;
    int threadblock_shape_m;
    int threadblock_shape_n;
@@ -91,12 +92,15 @@ struct GemmKey {
    int instruction_shape_k;
    int stages;
    int alignment_A;
    int alignment_B;
    SplitKMode split_k_mode;

    inline bool operator==(GemmKey const& rhs) const {
        return (element_A == rhs.element_A) && (layout_A == rhs.layout_A) &&
               (element_B == rhs.element_B) && (layout_B == rhs.layout_B) &&
               (element_C == rhs.element_C) && (layout_C == rhs.layout_C) &&
               (element_accumulator == rhs.element_accumulator) &&
               (threadblock_shape_m == rhs.threadblock_shape_m) &&
               (threadblock_shape_n == rhs.threadblock_shape_n) &&
               (threadblock_shape_k == rhs.threadblock_shape_k) &&
@@ -106,7 +110,9 @@ struct GemmKey {
               (instruction_shape_m == rhs.instruction_shape_m) &&
               (instruction_shape_n == rhs.instruction_shape_n) &&
               (instruction_shape_k == rhs.instruction_shape_k) &&
               (stages == rhs.stages) && (alignment_A == rhs.alignment_A) &&
               (alignment_B == rhs.alignment_B) &&
               (split_k_mode == rhs.split_k_mode);
    }

    inline bool operator!=(GemmKey const& rhs) const { return !(*this == rhs); }
@@ -130,10 +136,13 @@ struct GemmKey {
               "\n layout_B: " + to_string(layout_B) +
               "\n element_C: " + to_string(element_C) +
               "\n layout_C: " + to_string(layout_C) +
               "\n element_accumulator: " + to_string(element_accumulator) +
               "\n threadblock_shape: " + threadblock_shape_str +
               "\n warp_shape: " + warp_shape_str +
               "\n instruction_shape: " + instruction_shape_str +
               "\n stages: " + std::to_string(stages) +
               "\n alignment_A: " + std::to_string(alignment_A) +
               "\n alignment_B: " + std::to_string(alignment_B) +
               "\n split_k_mode: " + to_string(split_k_mode) + "\n}";
    }
};
@@ -147,6 +156,8 @@ struct GemmKeyHasher {
                .update(&key.layout_B, sizeof(key.layout_B))
                .update(&key.element_C, sizeof(key.element_C))
                .update(&key.layout_C, sizeof(key.layout_C))
                .update(&key.element_accumulator,
                        sizeof(key.element_accumulator))
                .update(&key.threadblock_shape_m,
                        sizeof(key.threadblock_shape_m))
                .update(&key.threadblock_shape_n,
@@ -157,6 +168,8 @@ struct GemmKeyHasher {
                .update(&key.warp_shape_n, sizeof(key.warp_shape_n))
                .update(&key.warp_shape_k, sizeof(key.warp_shape_k))
                .update(&key.stages, sizeof(key.stages))
                .update(&key.alignment_A, sizeof(key.alignment_A))
                .update(&key.alignment_B, sizeof(key.alignment_B))
                .update(&key.split_k_mode, sizeof(key.split_k_mode))
                .digest();
    }
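The invariant maintained here is that every field distinguishing a kernel instantiation participates in both operator== and the hasher; otherwise two different kernels (say, f16 vs f32 accumulation, which can now differ only in element_accumulator) would collide in the operation table. The same pattern in a Python sketch (field names abbreviated for illustration):

from dataclasses import dataclass

# frozen=True derives __eq__ and __hash__ over *all* declared fields, which
# is exactly the property GemmKey maintains by hand in C++.
@dataclass(frozen=True)
class Key:
    element_a: str
    element_c: str
    element_accumulator: str   # newly added discriminator
    alignment_a: int           # newly added discriminator
    alignment_b: int           # newly added discriminator

table = {Key("f16", "f16", "f32", 8, 8): "s1688 kernel",
         Key("f16", "f16", "f16", 8, 8): "h1688 kernel"}
print(table[Key("f16", "f16", "f16", 8, 8)])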
@@ -43,6 +43,12 @@ MatrixMulForwardImpl::AlgoPack::AlgoPack() {
    for (auto&& algo : simt_float32_gemv_batched_strided) {
        all_algos.push_back(&algo);
    }
    for (auto&& algo : tensorop_float16) {
        all_algos.push_back(&algo);
    }
    for (auto&& algo : tensorop_float16_split_k) {
        all_algos.push_back(&algo);
    }
#endif
    all_algos.push_back(&naive);
@@ -53,7 +59,7 @@ MatrixMulForwardImpl::AlgoPack::AlgoPack() {
#if CUDA_VERSION >= 9020
void MatrixMulForwardImpl::AlgoPack::fill_cutlass_algos() {
    using AlgoParam = AlgoCutlassMatrixMulBase::AlgoParam;
    simt_float32.emplace_back(AlgoParam{64, 256, 8, 32, 64, 8});
    simt_float32.emplace_back(AlgoParam{256, 64, 8, 64, 32, 8});
    simt_float32.emplace_back(AlgoParam{32, 256, 8, 16, 64, 8});
@@ -91,6 +97,19 @@ void MatrixMulForwardImpl::AlgoPack::fill_cutlass_algos() {
    simt_float32_gemv_batched_strided.emplace_back(128);
    simt_float32_gemv_batched_strided.emplace_back(64);
    simt_float32_gemv_batched_strided.emplace_back(32);
#define FOREACH_CUTLASS_MATMUL_F16_SHAPES(cb) \
    cb(256, 128, 32, 64, 64, 32, 8, 8, 4);    \
    cb(128, 256, 32, 64, 64, 32, 8, 8, 4);    \
    cb(128, 128, 32, 64, 64, 32, 8, 8, 4);    \
    cb(256, 128, 32, 64, 64, 32, 16, 8, 8);   \
    cb(128, 256, 32, 64, 64, 32, 16, 8, 8);   \
    cb(128, 128, 32, 64, 64, 32, 16, 8, 8);
#define cb(...)                                            \
    tensorop_float16.emplace_back(AlgoParam{__VA_ARGS__}); \
    tensorop_float16_split_k.emplace_back(AlgoParam{__VA_ARGS__});
    FOREACH_CUTLASS_MATMUL_F16_SHAPES(cb)
#undef cb
#undef FOREACH_CUTLASS_MATMUL_F16_SHAPES
}
#endif
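The X-macro above expands into twelve emplace_back calls: one plain and one split-k algorithm per (threadblock, warp, instruction) shape triple. The same registration logic, spelled out as a Python sketch:

F16_SHAPES = [
    # (tb_m, tb_n, tb_k, warp_m, warp_n, warp_k, inst_m, inst_n, inst_k)
    (256, 128, 32, 64, 64, 32, 8, 8, 4),   # mma m8n8k4  (Volta, tensorop884)
    (128, 256, 32, 64, 64, 32, 8, 8, 4),
    (128, 128, 32, 64, 64, 32, 8, 8, 4),
    (256, 128, 32, 64, 64, 32, 16, 8, 8),  # mma m16n8k8 (Turing, tensorop1688)
    (128, 256, 32, 64, 64, 32, 16, 8, 8),
    (128, 128, 32, 64, 64, 32, 16, 8, 8),
]
tensorop_float16 = list(F16_SHAPES)
tensorop_float16_split_k = list(F16_SHAPES)
assert len(tensorop_float16) + len(tensorop_float16_split_k) == 12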
@@ -41,11 +41,13 @@ public:
        CUDA_WMMA_UINT4X4X32,
        CUDA_CUBLASLT,
        CUDA_NAIVE,
        CUDA_BFLOAT16,
#if CUDA_VERSION >= 9020
        CUDA_FLOAT32_SIMT,
        CUDA_FLOAT32_SIMT_SPLIT_K,
        CUDA_FLOAT32_SIMT_GEMV_BATCHED_STRIDED,
        CUDA_FLOAT16_TENSOR_OP,
        CUDA_FLOAT16_TENSOR_OP_SPLIT_K,
#endif
    };
    using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>;
@@ -188,65 +190,83 @@ private:
#endif
#if CUDA_VERSION >= 9020
class MatrixMulForwardImpl::AlgoCutlassMatrixMulBase : public AlgoBase {
public:
    struct AlgoParam {
        int threadblock_m, threadblock_n, threadblock_k;
        int warp_m, warp_n, warp_k;
        int instruction_m, instruction_n, instruction_k;
        AlgoParam(int threadblock_m_, int threadblock_n_, int threadblock_k_,
                  int warp_m_, int warp_n_, int warp_k_, int instruction_m_ = 1,
                  int instruction_n_ = 1, int instruction_k_ = 1)
                : threadblock_m{threadblock_m_},
                  threadblock_n{threadblock_n_},
                  threadblock_k{threadblock_k_},
                  warp_m{warp_m_},
                  warp_n{warp_n_},
                  warp_k{warp_k_},
                  instruction_m{instruction_m_},
                  instruction_n{instruction_n_},
                  instruction_k{instruction_k_} {}
        std::string to_string() const;
    };
    AlgoCutlassMatrixMulBase(AlgoParam algo_param) : m_algo_param{algo_param} {}
    void exec(const ExecArgs& args) const override;
    std::string param() const override {
        std::string ret;
        serialize_write_pod(m_algo_param, ret);
        return ret;
    }

protected:
    virtual int min_alignment_requirement() const = 0;
    virtual void do_exec(const ExecArgs& args) const = 0;
    std::pair<bool, TensorLayoutArray> construct_aligned_layouts(
            const SizeArgs& args) const;
    int max_alignment(const SizeArgs& args) const;
    AlgoParam m_algo_param;
};

class MatrixMulForwardImpl::AlgoFloat32SIMT final
        : public AlgoCutlassMatrixMulBase {
public:
    AlgoFloat32SIMT(AlgoParam algo_param)
            : AlgoCutlassMatrixMulBase{algo_param},
              m_name{ssprintf("CUTLASS_FLOAT32_SIMT_%s",
                              m_algo_param.to_string().c_str())} {}
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    const char* name() const override { return m_name.c_str(); }
    AlgoAttribute attribute() const override {
        return AlgoAttribute::REPRODUCIBLE;
    }
    MEGDNN_DECL_ALGO_TYPE(CUDA_FLOAT32_SIMT)

private:
    void do_exec(const ExecArgs& args) const override;
    int min_alignment_requirement() const override { return 1; }
    std::string m_name;
};

class MatrixMulForwardImpl::AlgoFloat32SIMTSplitK final
        : public AlgoCutlassMatrixMulBase {
public:
    using AlgoParam = MatrixMulForwardImpl::AlgoFloat32SIMT::AlgoParam;
    AlgoFloat32SIMTSplitK(AlgoParam algo_param)
            : AlgoCutlassMatrixMulBase{algo_param},
              m_name{ssprintf("CUTLASS_FLOAT32_SIMT_SPLIT_K_%s",
                              m_algo_param.to_string().c_str())} {}
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    const char* name() const override { return m_name.c_str(); }
    AlgoAttribute attribute() const override {
        return AlgoAttribute::REPRODUCIBLE |
               AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
    }
    MEGDNN_DECL_ALGO_TYPE(CUDA_FLOAT32_SIMT_SPLIT_K)

private:
    void do_exec(const ExecArgs& args) const override;
    int min_alignment_requirement() const override { return 1; }
    std::string m_name;
};
@@ -276,6 +296,56 @@ private:
    int m_threadblock_n;
    std::string m_name;
};

class MatrixMulForwardImpl::AlgoFloat16TensorOp final
        : public AlgoCutlassMatrixMulBase {
public:
    AlgoFloat16TensorOp(AlgoParam algo_param)
            : AlgoCutlassMatrixMulBase{algo_param},
              m_name{ssprintf("CUTLASS_FLOAT16_TENSOR_OP_h%d%d%d_%s",
                              m_algo_param.instruction_m,
                              m_algo_param.instruction_n,
                              m_algo_param.instruction_k,
                              m_algo_param.to_string().c_str())} {}
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    const char* name() const override { return m_name.c_str(); }
    AlgoAttribute attribute() const override {
        return AlgoAttribute::REPRODUCIBLE;
    }
    MEGDNN_DECL_ALGO_TYPE(CUDA_FLOAT16_TENSOR_OP)

private:
    void do_exec(const ExecArgs& args) const override;
    int min_alignment_requirement() const override { return 2; }
    std::string m_name;
};

class MatrixMulForwardImpl::AlgoFloat16TensorOpSplitK final
        : public AlgoCutlassMatrixMulBase {
public:
    AlgoFloat16TensorOpSplitK(AlgoParam algo_param)
            : AlgoCutlassMatrixMulBase{algo_param},
              m_name{ssprintf("CUTLASS_FLOAT16_TENSOR_OP_SPLIT_K_h%d%d%d_%s",
                              m_algo_param.instruction_m,
                              m_algo_param.instruction_n,
                              m_algo_param.instruction_k,
                              m_algo_param.to_string().c_str())} {}
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    const char* name() const override { return m_name.c_str(); }
    AlgoAttribute attribute() const override {
        return AlgoAttribute::REPRODUCIBLE |
               AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
    }
    MEGDNN_DECL_ALGO_TYPE(CUDA_FLOAT16_TENSOR_OP_SPLIT_K)

private:
    void do_exec(const ExecArgs& args) const override;
    int min_alignment_requirement() const override { return 2; }
    std::string m_name;
};
#endif

class MatrixMulForwardImpl::AlgoPack : NonCopyableObj {
@@ -300,6 +370,8 @@ public:
    std::vector<AlgoFloat32SIMTSplitK> simt_float32_split_k;
    std::vector<AlgoFloat32SIMTGemvBatchedStrided>
            simt_float32_gemv_batched_strided;
    std::vector<AlgoFloat16TensorOp> tensorop_float16;
    std::vector<AlgoFloat16TensorOpSplitK> tensorop_float16_split_k;
#endif

    std::vector<AlgoBase*> all_algos;
@@ -0,0 +1,154 @@ | |||
/** | |||
* \file dnn/src/cuda/matrix_mul/cutlass_float16_tensorop.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "src/cuda/cutlass/singleton.h" | |||
#include "src/cuda/handle.h" | |||
#include "src/cuda/matrix_mul/algos.h" | |||
#include "src/cuda/utils.h" | |||
#if CUDA_VERSION >= 9020 | |||
using namespace megdnn; | |||
using namespace cuda; | |||
bool MatrixMulForwardImpl::AlgoFloat16TensorOp::is_available( | |||
const SizeArgs& args) const { | |||
bool available = | |||
args.opr->param().format == param::MatrixMul::Format::DEFAULT && | |||
args.layout_b.dtype == dtype::Float16() && | |||
args.layout_c.dtype == dtype::Float16(); | |||
int n = args.layout_c.shape[1]; | |||
auto&& device_prop = cuda::current_device_prop(); | |||
int y_grid_limit = device_prop.maxGridSize[1]; | |||
// limit y grid | |||
available &= ((n + m_algo_param.threadblock_n - 1) / | |||
m_algo_param.threadblock_n <= | |||
y_grid_limit); | |||
if (m_algo_param.instruction_m == 8 && m_algo_param.instruction_n == 8 && | |||
m_algo_param.instruction_k == 4) { | |||
available &= is_compute_capability_required(7, 0); | |||
} else { | |||
megdnn_assert(m_algo_param.instruction_m == 16 && | |||
m_algo_param.instruction_n == 8 && | |||
m_algo_param.instruction_k == 8); | |||
available &= is_compute_capability_required(7, 5); | |||
} | |||
return available; | |||
} | |||
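The instruction-shape gate mirrors the generator settings above: m8n8k4 is the Volta tensor-core shape (min_cc = 70) and m16n8k8 the Turing shape (min_cc = 75). A sketch of the mapping:

def min_compute_capability(inst_m, inst_n, inst_k):
    # Shapes match the MathInstruction entries in the generator above.
    if (inst_m, inst_n, inst_k) == (8, 8, 4):
        return (7, 0)   # Volta tensor cores (tensorop884 kernels)
    assert (inst_m, inst_n, inst_k) == (16, 8, 8)
    return (7, 5)       # Turing tensor cores (tensorop1688 kernels)

print(min_compute_capability(16, 8, 8))  # (7, 5)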
size_t MatrixMulForwardImpl::AlgoFloat16TensorOp::get_workspace_in_bytes(
        const SizeArgs& args) const {
    auto aligned = construct_aligned_layouts(args);
    if (!aligned.first)
        return 0_z;
    const auto& layouts = aligned.second;
    size_t ws_size = 0;
    for (auto&& ly : layouts) {
        ws_size += ly.span().dist_byte();
    }
    return ws_size;
}

void MatrixMulForwardImpl::AlgoFloat16TensorOp::do_exec(
        const ExecArgs& args) const {
    int64_t lda = args.tensor_a.layout.stride[0],
            ldb = args.tensor_b.layout.stride[0],
            ldc = args.tensor_c.layout.stride[0];
    int alignment = max_alignment(args);
    int min_alignment = min_alignment_requirement();
    auto&& param = args.opr->param();
    int m = args.tensor_c.layout.shape[0], n = args.tensor_c.layout.shape[1],
        k = args.tensor_a.layout.shape[param.transposeA ? 0 : 1];
    megdnn_assert(lda % alignment == 0 && ldb % alignment == 0 &&
                  ldc % alignment == 0 && m % alignment == 0 &&
                  n % alignment == 0 && k % alignment == 0 &&
                  alignment >= min_alignment);
    cutlass::gemm::GemmCoord problem_size{m, n, k};
    auto&& stream = cuda_stream(args.opr->handle());
    int* workspace = reinterpret_cast<int*>(args.workspace.raw_ptr);
    // \note the epilogue constants (i.e. one and zero) are passed to cutlass
    // by pointer and reinterpreted as ElementCompute*, which is then used to
    // initialize the kernel parameters. The host-side arguments must
    // therefore have the same type as the ElementCompute of the kernel
    // instance; otherwise the kernel exhibits undefined behavior caused by
    // misinterpretation of these pointers.
    float one = 1.f, zero = 0.f;
    dt_float16 one_f16 = static_cast<dt_float16>(one),
               zero_f16 = static_cast<dt_float16>(zero);
    using namespace cutlass::library;
    auto layoutA = param.transposeA ? LayoutTypeID::kColumnMajor
                                    : LayoutTypeID::kRowMajor;
    auto layoutB = param.transposeB ? LayoutTypeID::kColumnMajor
                                    : LayoutTypeID::kRowMajor;
    void *host_one, *host_zero;
    NumericTypeID element_accumulator;
    if (param.compute_mode == param::MatrixMul::ComputeMode::DEFAULT) {
        element_accumulator = NumericTypeID::kF16;
        host_one = &one_f16;
        host_zero = &zero_f16;
    } else {
        megdnn_assert(param.compute_mode ==
                      param::MatrixMul::ComputeMode::FLOAT32);
        element_accumulator = NumericTypeID::kF32;
        host_one = &one;
        host_zero = &zero;
    }
    GemmKey key{NumericTypeID::kF16,
                layoutA,
                NumericTypeID::kF16,
                layoutB,
                NumericTypeID::kF16,
                LayoutTypeID::kRowMajor,
                element_accumulator,
                m_algo_param.threadblock_m,
                m_algo_param.threadblock_n,
                m_algo_param.threadblock_k,
                m_algo_param.warp_m,
                m_algo_param.warp_n,
                m_algo_param.warp_k,
                m_algo_param.instruction_m,
                m_algo_param.instruction_n,
                m_algo_param.instruction_k,
                2,
                alignment,
                alignment,
                SplitKMode::kNone};
    const auto& table = Singleton::get().operation_table;
    megdnn_assert(table.gemm_operations.count(key) > 0,
                  "key not found in cutlass operation table");
    const auto& ops = table.gemm_operations.at(key);
    megdnn_assert(ops.size() == 1, "exactly one kernel expected, got %zu",
                  ops.size());
    GemmArguments gemm_args{problem_size,
                            args.tensor_a.raw_ptr,
                            args.tensor_b.raw_ptr,
                            args.tensor_c.raw_ptr,
                            args.tensor_c.raw_ptr,
                            lda,
                            ldb,
                            ldc,
                            ldc,
                            1,
                            host_one,
                            host_zero};
    cutlass_check(ops[0]->run(&gemm_args, workspace, stream));
}
#endif

// vim: syntax=cpp.doxygen
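That host-side type-punning caveat is easy to reproduce: the bytes of a float32 1.0f, read as a float16, are nothing like 1.0. A quick illustration (numpy used only for the bit reinterpretation):

import numpy as np

one_f32 = np.float32(1.0)
# Reinterpret the first two bytes of the f32 as an f16, which is what the
# kernel would effectively do if host_one pointed at the wrong type.
wrong = one_f32.tobytes()[:2]
print(np.frombuffer(wrong, dtype=np.float16)[0])  # 0.0, not 1.0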
@@ -0,0 +1,165 @@
/**
 * \file dnn/src/cuda/matrix_mul/cutlass_float16_tensorop_split_k.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "src/cuda/cutlass/singleton.h"
#include "src/cuda/handle.h"
#include "src/cuda/matrix_mul/algos.h"
#include "src/cuda/utils.h"

#if CUDA_VERSION >= 9020
using namespace megdnn;
using namespace cuda;

bool MatrixMulForwardImpl::AlgoFloat16TensorOpSplitK::is_available(
        const SizeArgs& args) const {
    auto&& param = args.opr->param();
    int n = args.layout_c.shape[1],
        k = args.layout_a.shape[param.transposeA ? 0 : 1];
    bool available =
            args.opr->param().format == param::MatrixMul::Format::DEFAULT &&
            args.layout_a.dtype == dtype::Float16() &&
            args.layout_b.dtype == dtype::Float16() &&
            args.layout_c.dtype == dtype::Float16() && k > n;
    auto&& device_prop = cuda::current_device_prop();
    int y_grid_limit = device_prop.maxGridSize[1];
    // limit y grid
    available &= ((n + m_algo_param.threadblock_n - 1) /
                          m_algo_param.threadblock_n <=
                  y_grid_limit);
    if (m_algo_param.instruction_m == 8 && m_algo_param.instruction_n == 8 &&
        m_algo_param.instruction_k == 4) {
        available &= is_compute_capability_required(7, 0);
    } else {
        megdnn_assert(m_algo_param.instruction_m == 16 &&
                      m_algo_param.instruction_n == 8 &&
                      m_algo_param.instruction_k == 8);
        available &= is_compute_capability_required(7, 5);
    }
    return available;
}

size_t MatrixMulForwardImpl::AlgoFloat16TensorOpSplitK::get_workspace_in_bytes(
        const SizeArgs& args) const {
    auto aligned = construct_aligned_layouts(args);
    auto&& param = args.opr->param();
    int m = args.layout_c.shape[0], n = args.layout_c.shape[1],
        k = args.layout_a.shape[param.transposeA ? 0 : 1];
    int split_k_slices = std::max(1, k / n);
    if (!aligned.first)
        return args.layout_c.dtype.size(m * n * split_k_slices);
    const auto& layouts = aligned.second;
    int align_m = layouts[2].shape[0], align_n = layouts[2].shape[1],
        align_k = layouts[0].shape[1];
    split_k_slices = std::max(1, align_k / align_n);
    size_t ws_size =
            args.layout_c.dtype.size(align_m * align_n * split_k_slices);
    for (auto&& ly : layouts)
        ws_size += ly.span().dist_byte();
    return ws_size;
}
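A worked example of the workspace math on the simple (already aligned) path, using illustrative sizes: for an f16 problem with m = 128, n = 256, k = 4096, each of the k/n = 16 slices needs its own m x n partial-sum buffer before the parallel reduction:

m, n, k = 128, 256, 4096
dtype_size = 2                   # f16 bytes
split_k_slices = max(1, k // n)  # 16 slices (is_available requires k > n)
workspace = m * n * split_k_slices * dtype_size
print(split_k_slices, workspace)  # 16, 1048576 bytes (1 MiB)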
void MatrixMulForwardImpl::AlgoFloat16TensorOpSplitK::do_exec(
        const ExecArgs& args) const {
    int64_t lda = args.tensor_a.layout.stride[0],
            ldb = args.tensor_b.layout.stride[0],
            ldc = args.tensor_c.layout.stride[0];
    int alignment = max_alignment(args);
    int min_alignment = min_alignment_requirement();
    auto&& param = args.opr->param();
    int m = args.tensor_c.layout.shape[0], n = args.tensor_c.layout.shape[1],
        k = args.tensor_a.layout.shape[param.transposeA ? 0 : 1];
    megdnn_assert(lda % alignment == 0 && ldb % alignment == 0 &&
                  ldc % alignment == 0 && m % alignment == 0 &&
                  n % alignment == 0 && k % alignment == 0 &&
                  alignment >= min_alignment);
    cutlass::gemm::GemmCoord problem_size{m, n, k};
    int split_k_slices = std::max(1, k / n);
    auto&& stream = cuda_stream(args.opr->handle());
    int* workspace = reinterpret_cast<int*>(args.workspace.raw_ptr);
    // \note the epilogue constants (i.e. one and zero) are passed to cutlass
    // by pointer and reinterpreted as ElementCompute*, which is then used to
    // initialize the kernel parameters. The host-side arguments must
    // therefore have the same type as the ElementCompute of the kernel
    // instance; otherwise the kernel exhibits undefined behavior caused by
    // misinterpretation of these pointers.
    float one = 1.f, zero = 0.f;
    dt_float16 one_f16 = static_cast<dt_float16>(one),
               zero_f16 = static_cast<dt_float16>(zero);
    using namespace cutlass::library;
    auto layoutA = param.transposeA ? LayoutTypeID::kColumnMajor
                                    : LayoutTypeID::kRowMajor;
    auto layoutB = param.transposeB ? LayoutTypeID::kColumnMajor
                                    : LayoutTypeID::kRowMajor;
    void *host_one, *host_zero;
    NumericTypeID element_accumulator;
    if (param.compute_mode == param::MatrixMul::ComputeMode::DEFAULT) {
        element_accumulator = NumericTypeID::kF16;
        host_one = &one_f16;
        host_zero = &zero_f16;
    } else {
        megdnn_assert(param.compute_mode ==
                      param::MatrixMul::ComputeMode::FLOAT32);
        element_accumulator = NumericTypeID::kF32;
        host_one = &one;
        host_zero = &zero;
    }
    GemmKey key{NumericTypeID::kF16,
                layoutA,
                NumericTypeID::kF16,
                layoutB,
                NumericTypeID::kF16,
                LayoutTypeID::kRowMajor,
                element_accumulator,
                m_algo_param.threadblock_m,
                m_algo_param.threadblock_n,
                m_algo_param.threadblock_k,
                m_algo_param.warp_m,
                m_algo_param.warp_n,
                m_algo_param.warp_k,
                m_algo_param.instruction_m,
                m_algo_param.instruction_n,
                m_algo_param.instruction_k,
                2,
                alignment,
                alignment,
                SplitKMode::kParallel};
    const auto& table = Singleton::get().operation_table;
    megdnn_assert(table.gemm_operations.count(key) > 0,
                  "key not found in cutlass operation table");
    const auto& ops = table.gemm_operations.at(key);
    megdnn_assert(ops.size() == 1, "exactly one kernel expected, got %zu",
                  ops.size());
    GemmArguments gemm_args{problem_size,
                            args.tensor_a.raw_ptr,
                            args.tensor_b.raw_ptr,
                            args.tensor_c.raw_ptr,
                            args.tensor_c.raw_ptr,
                            lda,
                            ldb,
                            ldc,
                            ldc,
                            split_k_slices,
                            host_one,
                            host_zero};
    cutlass_check(ops[0]->run(&gemm_args, workspace, stream));
}
#endif

// vim: syntax=cpp.doxygen
@@ -42,7 +42,8 @@ size_t MatrixMulForwardImpl::AlgoFloat32SIMT::get_workspace_in_bytes(
    return 0_z;
}

void MatrixMulForwardImpl::AlgoFloat32SIMT::do_exec(
        const ExecArgs& args) const {
    int64_t lda = args.tensor_a.layout.stride[0],
            ldb = args.tensor_b.layout.stride[0],
            ldc = args.tensor_c.layout.stride[0];
@@ -65,12 +66,14 @@ void MatrixMulForwardImpl::AlgoFloat32SIMT::exec(const ExecArgs& args) const {
    auto layoutB = param.transposeB ? LayoutTypeID::kColumnMajor
                                    : LayoutTypeID::kRowMajor;
    int alignment = min_alignment_requirement();
    GemmKey key{NumericTypeID::kF32,
                layoutA,
                NumericTypeID::kF32,
                layoutB,
                NumericTypeID::kF32,
                LayoutTypeID::kRowMajor,
                NumericTypeID::kF32,
                m_algo_param.threadblock_m,
                m_algo_param.threadblock_n,
                m_algo_param.threadblock_k,
@@ -79,8 +82,10 @@ void MatrixMulForwardImpl::AlgoFloat32SIMT::exec(const ExecArgs& args) const {
                m_algo_param.warp_k,
                1,
                1,
                1,
                2,
                alignment,
                alignment,
                SplitKMode::kNone};
    const Operation* op = Singleton::get().operation_table.find_op(key);
@@ -22,7 +22,7 @@ using namespace cuda; | |||
bool MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::is_available( | |||
const SizeArgs& args) const { | |||
auto&& param = args.opr->param(); | |||
int m = args.layout_c.shape[0], n = args.layout_c.shape[1], | |||
int n = args.layout_c.shape[1], | |||
k = args.layout_a.shape[param.transposeA ? 0 : 1]; | |||
bool available = | |||
args.opr->param().format == param::MatrixMul::Format::DEFAULT && | |||
@@ -32,8 +32,8 @@ bool MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::is_available( | |||
auto&& device_prop = cuda::current_device_prop(); | |||
int y_grid_limit = device_prop.maxGridSize[1]; | |||
// limit y grid | |||
available &= ((m + m_algo_param.threadblock_m - 1) / | |||
m_algo_param.threadblock_m <= | |||
available &= ((n + m_algo_param.threadblock_n - 1) / | |||
m_algo_param.threadblock_n <= | |||
y_grid_limit); | |||
return available; | |||
} | |||
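Aside (not part of the diff): the guard now counts threadblocks along n rather than m, which suggests the split-k swizzle used here places the n tiling on gridDim.y. A back-of-envelope check with illustrative numbers (65535 is maxGridSize[1] on current NVIDIA GPUs):

#include <cstdio>

int main() {
    long n = 8'000'000;          // extreme n to show the limit biting
    int threadblock_n = 64;
    int y_grid_limit = 65535;    // device_prop.maxGridSize[1]
    long blocks_y = (n + threadblock_n - 1) / threadblock_n;  // ceil-div
    printf("blocks along n: %ld -> %s\n", blocks_y,
           blocks_y <= y_grid_limit ? "available" : "rejected");
}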
@@ -47,7 +47,7 @@ size_t MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::get_workspace_in_bytes(
    return args.layout_c.dtype.size(m * n * split_k_slices);
}
void MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::exec(
void MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::do_exec(
        const ExecArgs& args) const {
    int64_t lda = args.tensor_a.layout.stride[0],
            ldb = args.tensor_b.layout.stride[0],
@@ -72,12 +72,14 @@ void MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::exec(
    auto layoutB = param.transposeB ? LayoutTypeID::kColumnMajor
                                    : LayoutTypeID::kRowMajor;
    int alignment = min_alignment_requirement();
    GemmKey key{NumericTypeID::kF32,
                layoutA,
                NumericTypeID::kF32,
                layoutB,
                NumericTypeID::kF32,
                LayoutTypeID::kRowMajor,
                NumericTypeID::kF32,
                m_algo_param.threadblock_m,
                m_algo_param.threadblock_n,
                m_algo_param.threadblock_k,
@@ -87,7 +89,9 @@ void MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::exec(
                1,
                1,
                1,
                2,
                2,
                alignment,
                alignment,
                SplitKMode::kParallel};
    Operation const* op = Singleton::get().operation_table.find_op(key);
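Aside (not part of the diff): the parallel split-k kernels write split_k_slices partial products into the workspace before reducing them, which is where the m * n * split_k_slices sizing above comes from, combined with the split_k_slices = max(1, k / n) heuristic seen in the exec paths. Worked numbers:

#include <algorithm>
#include <cstdio>

int main() {
    int m = 128, n = 128, k = 65536;          // thin output, deep reduction
    int split_k_slices = std::max(1, k / n);  // 512 slices
    size_t bytes = sizeof(float) * size_t(m) * n * split_k_slices;
    printf("slices=%d, workspace=%zu MiB\n", split_k_slices,
           bytes >> 20);                      // 32 MiB of f32 partials
}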
@@ -0,0 +1,136 @@
/**
 * \file dnn/src/cuda/matrix_mul/cutlass_matrix_mul_base.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
 */
#include "src/cuda/handle.h"
#include "src/cuda/matrix_mul/algos.h"
#include "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuh"
#include "src/cuda/utils.h"
#if CUDA_VERSION >= 9020
using namespace megdnn;
using namespace cuda;
std::string
MatrixMulForwardImpl::AlgoCutlassMatrixMulBase::AlgoParam::to_string() const {
    return ssprintf("%dX%dX%d_%dX%dX%d", threadblock_m, threadblock_n,
                    threadblock_k, warp_m, warp_n, warp_k);
}
std::pair<bool, TensorLayoutArray>
MatrixMulForwardImpl::AlgoCutlassMatrixMulBase::construct_aligned_layouts(
        const SizeArgs& args) const {
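    // decide whether the problem needs padding: if every extent and stride
    // already satisfies the kernel's minimum alignment, report "no padding
    // needed" (first == false) with empty layouts; otherwise round m/n/k up
    // via get_aligned_power2 and return padded A/B/C layouts for the caller
    // to relayout into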
    int alignment = max_alignment(args);
    int min_alignment = min_alignment_requirement();
    bool aligned = alignment >= min_alignment;
    if (aligned)
        return std::make_pair(!aligned, TensorLayoutArray{{}});
    auto&& param = args.opr->param();
    int m = args.layout_c.shape[0], n = args.layout_c.shape[1],
        k = args.layout_a.shape[param.transposeA ? 0 : 1];
    size_t align_m = get_aligned_power2(m, min_alignment);
    size_t align_n = get_aligned_power2(n, min_alignment);
    size_t align_k = get_aligned_power2(k, min_alignment);
    TensorLayoutArray layouts;
    layouts.emplace_back(TensorLayout{{align_m, align_k}, args.layout_a.dtype});
    layouts.emplace_back(TensorLayout{{align_k, align_n}, args.layout_b.dtype});
    layouts.emplace_back(TensorLayout{{align_m, align_n}, args.layout_c.dtype});
    return std::make_pair(!aligned, std::move(layouts));
}
void MatrixMulForwardImpl::AlgoCutlassMatrixMulBase::exec(
        const ExecArgs& args) const {
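    // fast path: run the kernel directly when the problem is aligned;
    // otherwise zero a padded workspace, relayout A and B into it, run the
    // aligned kernel on the padded problem, then crop the result back into
    // the caller's C tensor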
    auto aligned = construct_aligned_layouts(args);
    if (!aligned.first)
        return do_exec(args);
    const auto& layouts = aligned.second;
    auto tensor_a = args.tensor_a;
    auto tensor_b = args.tensor_b;
    auto workspace = args.workspace;
    size_t copy_size = 0;
    for (const auto& ly : layouts)
        copy_size += ly.span().dist_byte();
    auto&& param = args.opr->param();
    auto&& stream = cuda_stream(args.opr->handle());
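    // zero the whole padded region up front so the padding rows/columns
    // contribute nothing to the product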
    cuda_check(cudaMemsetAsync(workspace.raw_ptr, 0, copy_size, stream));
    auto&& relayout = args.opr->handle()->create_operator<RelayoutForward>();
    auto copy_stride = [](const TensorLayout& src, TensorLayout& dst,
                          bool trans) {
        dst.stride[0] = src.stride[0], dst.stride[1] = src.stride[1];
        if (trans)
            std::swap(dst.stride[0], dst.stride[1]);
    };
    copy_stride(layouts[0], tensor_a.layout, param.transposeA);
    tensor_a.raw_ptr = workspace.raw_ptr;
    relayout->exec(args.tensor_a, tensor_a);
    workspace.raw_ptr += layouts[0].span().dist_byte();
    workspace.size -= layouts[0].span().dist_byte();
    copy_stride(layouts[1], tensor_b.layout, param.transposeB);
    tensor_b.raw_ptr = workspace.raw_ptr;
    relayout->exec(args.tensor_b, tensor_b);
    workspace.raw_ptr += layouts[1].span().dist_byte();
    workspace.size -= layouts[1].span().dist_byte();
    decltype(tensor_a) tensor_c{workspace.raw_ptr, layouts[2]};
    workspace.raw_ptr += layouts[2].span().dist_byte();
    workspace.size -= layouts[2].span().dist_byte();
    auto&& matmul = args.opr->handle()->create_operator<MatrixMulForward>();
    matmul->param().transposeA = false;
    matmul->param().transposeB = false;
    matmul->param().compute_mode = args.opr->param().compute_mode;
    tensor_a.layout = layouts[0];
    tensor_b.layout = layouts[1];
    ExecArgs args_{static_cast<MatrixMulForwardImpl*>(matmul.get()), tensor_a,
                   tensor_b, tensor_c, workspace};
    do_exec(args_);
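    // shrink the padded result back to the requested shape (strides keep the
    // padded pitch) before copying it out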
    tensor_c.layout.TensorShape::operator=(args.layout_c);
    relayout->exec(tensor_c, args.tensor_c);
}
int MatrixMulForwardImpl::AlgoCutlassMatrixMulBase::max_alignment(
        const SizeArgs& args) const {
    auto&& dtype_a = args.layout_a.dtype;
    auto&& dtype_b = args.layout_b.dtype;
    auto&& dtype_c = args.layout_c.dtype;
    auto get_alignment = [](const DType& dt, int len) {
        int size_bits = dt.size(1) * 8;
        int align = 128;
        while (align > 1) {
            if ((len * size_bits) % align == 0)
                break;
            align = align / 2;
        }
        return align / size_bits;
    };
    int lda = args.layout_a.stride[0], ldb = args.layout_b.stride[0],
        ldc = args.layout_c.stride[0];
    auto&& param = args.opr->param();
    int m = args.layout_c.shape[0], n = args.layout_c.shape[1],
        k = args.layout_a.shape[param.transposeA ? 0 : 1];
    int max_align = get_alignment(dtype_a, lda);
    max_align = std::min(get_alignment(dtype_a, m), max_align);
    max_align = std::min(get_alignment(dtype_a, n), max_align);
    max_align = std::min(get_alignment(dtype_a, k), max_align);
    max_align = std::min(get_alignment(dtype_a, lda), max_align);
    max_align = std::min(get_alignment(dtype_b, ldb), max_align);
    max_align = std::min(get_alignment(dtype_c, ldc), max_align);
    return max_align;
}
#endif
// vim: syntax=cpp.doxygen
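Aside (not part of the diff): get_alignment above searches for the widest power-of-two access, at most 128 bits, that divides len elements of the given dtype, and reports it in elements; max_alignment is then the minimum over every extent and leading stride. The same logic in isolation, with worked f16 numbers:

#include <algorithm>
#include <cstdio>

// widest power-of-two access (<= 128 bits) dividing len elements of
// size_bits each, returned in elements -- mirrors get_alignment above
int alignment_in_elements(int size_bits, int len) {
    int align = 128;
    while (align > 1 && (len * size_bits) % align != 0)
        align /= 2;
    return align / size_bits;
}

int main() {
    // f16 (16-bit) example: lda=384 allows 8-element (128-bit) access,
    // but m=12 caps the whole problem at 4 elements (64-bit)
    int max_align = alignment_in_elements(16, 384);
    max_align = std::min(alignment_in_elements(16, 12), max_align);
    printf("alignment: %d elements\n", max_align);  // prints 4
}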
@@ -42,9 +42,12 @@ public:
    class AlgoBFloat16;
#endif
#if CUDA_VERSION >= 9020
    class AlgoCutlassMatrixMulBase;
    class AlgoFloat32SIMT;
    class AlgoFloat32SIMTSplitK;
    class AlgoFloat32SIMTGemvBatchedStrided;
    class AlgoFloat16TensorOp;
    class AlgoFloat16TensorOpSplitK;
#endif
    class AlgoPack;
@@ -184,7 +184,8 @@ void matrix_mul::check_matrix_mul(DType A_dtype, DType B_dtype, DType C_dtype,
                                  const ExecutionPolicyAlgoName& algo,
                                  param::MatrixMul::Format format, size_t nbase,
                                  float eps, std::vector<TestArg>&& user_args,
                                  bool force_deduce_dst) {
                                  bool force_deduce_dst,
                                  param::MatrixMul::ComputeMode compute_mode) {
    megdnn_assert(A_dtype.enumv() == B_dtype.enumv());
    Checker<Opr> checker(handle);
    checker.set_force_deduce_dst(force_deduce_dst);
@@ -261,6 +262,7 @@ void matrix_mul::check_matrix_mul(DType A_dtype, DType B_dtype, DType C_dtype,
        Param param;
        param.transposeA = arg.mask & 0x1;
        param.transposeB = arg.mask & 0x2;
        param.compute_mode = compute_mode;
        param.format = format;
        checker.set_dtype(0, A_dtype)
                .set_dtype(1, B_dtype)
@@ -69,7 +69,9 @@ void check_matrix_mul(
        const ExecutionPolicyAlgoName& algo = {"", {}},
        param::MatrixMul::Format format = param::MatrixMul::Format::DEFAULT,
        size_t nbase = 8, float eps = 1e-3, std::vector<TestArg>&& args = {},
        bool force_deduce_dst = true);
        bool force_deduce_dst = true,
        param::MatrixMul::ComputeMode compute_mode =
                param::MatrixMul::ComputeMode::DEFAULT);
void check_matrix_mul(
        DType A_dtype, DType B_dtype, DType C_dtype, Handle* handle,
@@ -21,6 +21,7 @@
#include "test/cuda/fixture.h"
#include "test/cuda/utils.h"
#define MEGDNN_WITH_BENCHMARK 1
#if CUDA_VERSION >= 9020
namespace megdnn {
namespace test {
@@ -215,6 +216,14 @@ std::vector<BenchArgs> get_feat_model_args() {
    return args;
}
std::vector<BenchArgs> get_f16_feat_model_args() {
    std::vector<BenchArgs> args;
    args.emplace_back(BenchArgs{128, 9216, 9216});
    args.emplace_back(BenchArgs{128, 6400, 6400});
    args.emplace_back(BenchArgs{128, 5184, 5184});
    return args;
}
void benchmark_matrix_mul(
        Handle* handle, const std::vector<BenchArgs>& args, DType A_dtype,
        DType B_dtype, DType C_dtype, const char* algo = nullptr,
@@ -364,6 +373,82 @@ MEGDNN_FOREACH_CUTLASS_KERNEL(cb)
#undef cb
#undef MEGDNN_FOREACH_CUTLASS_KERNEL
#define MEGDNN_FOREACH_CUTLASS_KERNEL(cb) \
    cb(1, 256, 128, 32, 64, 64, 32, 8, 8, 4); \
    cb(2, 128, 256, 32, 64, 64, 32, 8, 8, 4); \
    cb(3, 128, 128, 32, 64, 64, 32, 8, 8, 4);
#define cb(name, tbm, tbn, tbk, wm, wn, wk, im, in, ik) \
    TEST_F(CUDA, CUTLASS_F16_884_GEMM_##name) { \
        require_compute_capability(7, 0); \
        matrix_mul::check_matrix_mul<MatrixMulForward>( \
                dtype::Float16(), dtype::Float16(), dtype::Float16(), \
                handle_cuda(), \
                "CUTLASS_FLOAT16_TENSOR_OP_h" #im #in #ik "_" #tbm "X" #tbn \
                "X" #tbk "_" #wm "X" #wn "X" #wk, \
                param::MatrixMul::Format::DEFAULT, 8, 1e-2, \
                matrix_mul::get_matmul_args()); \
    }
MEGDNN_FOREACH_CUTLASS_KERNEL(cb)
#undef cb
#define cb(name, tbm, tbn, tbk, wm, wn, wk, im, in, ik) \
    TEST_F(CUDA, CUTLASS_F16_884_GEMM_SPLIT_K_##name) { \
        require_compute_capability(7, 0); \
        matrix_mul::check_matrix_mul<MatrixMulForward>( \
                dtype::Float16(), dtype::Float16(), dtype::Float16(), \
                handle_cuda(), \
                "CUTLASS_FLOAT16_TENSOR_OP_SPLIT_K_h" #im #in #ik "_" #tbm \
                "X" #tbn "X" #tbk "_" #wm "X" #wn "X" #wk, \
                param::MatrixMul::Format::DEFAULT, 8, 1e-3, \
                matrix_mul::get_matmul_args_split_k(), true, \
                param::MatrixMul::ComputeMode::FLOAT32); \
    }
MEGDNN_FOREACH_CUTLASS_KERNEL(cb)
#undef cb
#undef MEGDNN_FOREACH_CUTLASS_KERNEL
#define MEGDNN_FOREACH_CUTLASS_KERNEL(cb) \
    cb(1, 256, 128, 32, 64, 64, 32, 16, 8, 8); \
    cb(2, 128, 256, 32, 64, 64, 32, 16, 8, 8); \
    cb(3, 128, 128, 32, 64, 64, 32, 16, 8, 8);
#define cb(name, tbm, tbn, tbk, wm, wn, wk, im, in, ik) \
    TEST_F(CUDA, CUTLASS_F16_1688_GEMM_##name) { \
        require_compute_capability(7, 5); \
        matrix_mul::check_matrix_mul<MatrixMulForward>( \
                dtype::Float16(), dtype::Float16(), dtype::Float16(), \
                handle_cuda(), \
                "CUTLASS_FLOAT16_TENSOR_OP_h" #im #in #ik "_" #tbm "X" #tbn \
                "X" #tbk "_" #wm "X" #wn "X" #wk, \
                param::MatrixMul::Format::DEFAULT, 8, 1e-2, \
                matrix_mul::get_matmul_args(), true, \
                param::MatrixMul::ComputeMode::FLOAT32); \
    }
MEGDNN_FOREACH_CUTLASS_KERNEL(cb)
#undef cb
#define cb(name, tbm, tbn, tbk, wm, wn, wk, im, in, ik) \
    TEST_F(CUDA, CUTLASS_F16_1688_GEMM_SPLIT_K_##name) { \
        require_compute_capability(7, 5); \
        matrix_mul::check_matrix_mul<MatrixMulForward>( \
                dtype::Float16(), dtype::Float16(), dtype::Float16(), \
                handle_cuda(), \
                "CUTLASS_FLOAT16_TENSOR_OP_SPLIT_K_h" #im #in #ik "_" #tbm \
                "X" #tbn "X" #tbk "_" #wm "X" #wn "X" #wk, \
                param::MatrixMul::Format::DEFAULT, 8, 1e-3, \
                matrix_mul::get_matmul_args_split_k()); \
    }
MEGDNN_FOREACH_CUTLASS_KERNEL(cb)
#undef cb
#undef MEGDNN_FOREACH_CUTLASS_KERNEL
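Aside (not part of the diff): each cb line above expands into one test pinned to a single generated kernel; the h884 group requires compute capability 7.0 (Volta's 8x8x4 tensor core mma) and the h1688 group 7.5 (Turing's 16x8x8), matching the require_compute_capability calls. For reference, the expansion of the first 1688 entry with the string pasting spelled out:

// cb(1, 256, 128, 32, 64, 64, 32, 16, 8, 8) expands to:
TEST_F(CUDA, CUTLASS_F16_1688_GEMM_1) {
    require_compute_capability(7, 5);
    matrix_mul::check_matrix_mul<MatrixMulForward>(
            dtype::Float16(), dtype::Float16(), dtype::Float16(),
            handle_cuda(),
            "CUTLASS_FLOAT16_TENSOR_OP_h1688_256X128X32_64X64X32",
            param::MatrixMul::Format::DEFAULT, 8, 1e-2,
            matrix_mul::get_matmul_args(), true,
            param::MatrixMul::ComputeMode::FLOAT32);
}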
#if MEGDNN_WITH_BENCHMARK
TEST_F(CUDA, BENCHMARK_CUTLASS_MATMUL) {
    benchmark_matrix_mul(handle_cuda(), get_square_matmul_args(),
@@ -376,6 +461,12 @@ TEST_F(CUDA, BENCHMARK_CUTLASS_MATMUL_FEAT) {
                         dtype::Float32(), dtype::Float32(),
                         "CUTLASS_FLOAT32_SIMT");
}
TEST_F(CUDA, BENCHMARK_CUTLASS_F16_MATMUL_FEAT) {
    benchmark_matrix_mul(handle_cuda(), get_f16_feat_model_args(),
                         dtype::Float16(), dtype::Float16(), dtype::Float16(),
                         "CUTLASS_FLOAT16_TENSOR_OP");
}
#endif
}  // namespace test
}  // namespace megdnn