GitOrigin-RevId: 025c591f75
@@ -5,6 +5,8 @@ genrule(
     outs = cutlass_gen_list,
     cmd = """GEN=$(location //brain/megbrain/dnn/scripts/cutlass_generator:generator.py)
     CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations gemm --type simt $(@D)
+    CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations gemm --type tensorop884 $(@D)
+    CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations gemm --type tensorop1688 $(@D)
     CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations gemv --type simt $(@D)
     CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations deconv --type simt $(@D)
     CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations conv2d --type simt $(@D)
@@ -252,7 +252,8 @@ def GeneratesGemm(tile, data_type, layout_a, layout_b, layout_c, min_cc, align_a
     if tile.math_instruction.element_accumulator == DataType.s32:
         epilogues = [EpilogueFunctor.LinearCombinationClamp]
     else:
-        assert tile.math_instruction.element_accumulator == DataType.f32
+        assert tile.math_instruction.element_accumulator == DataType.f32 or \
+               tile.math_instruction.element_accumulator == DataType.f16
         epilogues = [EpilogueFunctor.LinearCombination]

     for epilogue in epilogues:
@@ -799,7 +800,22 @@ class EmitGemmSplitKParallelInstance:
    ${epilogue_vector_length},
    ${element_accumulator},
    ${element_epilogue}
-  >
+  >,
+  cutlass::epilogue::thread::Convert<
+    ${element_accumulator},
+    ${epilogue_vector_length},
+    ${element_accumulator}
+  >,
+  cutlass::reduction::thread::ReduceAdd<
+    ${element_accumulator},
+    ${element_accumulator},
+    ${epilogue_vector_length}
+  >,
+  cutlass::gemm::threadblock::GemmSplitKHorizontalThreadblockSwizzle,
+  ${stages},
+  ${align_a},
+  ${align_b},
+  ${math_operation}
 >;
 """

   def emit(self, operation):
@@ -831,7 +847,10 @@ class EmitGemmSplitKParallelInstance:
       'epilogue_vector_length': str(epilogue_vector_length),
       'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
       'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
-      'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor],
+      'stages': str(operation.tile_description.stages),
+      'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation],
+      'align_a': str(operation.A.alignment),
+      'align_b': str(operation.B.alignment),
     }

     return SubstituteTemplate(self.template, values)
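
Note: for orientation, the fully substituted instance that the split-k template above emits should look roughly like the sketch below. The concrete values (f16 output, f32 accumulator, 128x128x32 threadblock, mma.884, alignment 8, epilogue vector length 8) are illustrative assumptions, not output copied from the generator:

    // Hypothetical rendering of EmitGemmSplitKParallelInstance for one
    // cutlass_tensorop_f16_s884gemm_split_k_parallel_* kernel.
    using Operation_example = cutlass::gemm::device::GemmSplitKParallel<
        cutlass::half_t, cutlass::layout::RowMajor,    // A: ${element_a}, ${layout_a}
        cutlass::half_t, cutlass::layout::RowMajor,    // B: ${element_b}, ${layout_b}
        cutlass::half_t, cutlass::layout::RowMajor,    // C: ${element_c}, ${layout_c}
        float,                                         // ${element_accumulator}
        cutlass::arch::OpClassTensorOp,
        cutlass::arch::Sm70,
        cutlass::gemm::GemmShape<128, 128, 32>,        // threadblock shape
        cutlass::gemm::GemmShape<64, 64, 32>,          // warp shape
        cutlass::gemm::GemmShape<8, 8, 4>,             // instruction shape
        cutlass::epilogue::thread::LinearCombination<  // ${epilogue_functor}
            cutlass::half_t, 8, float, float>,
        cutlass::epilogue::thread::Convert<float, 8, float>,
        cutlass::reduction::thread::ReduceAdd<float, float, 8>,
        cutlass::gemm::threadblock::GemmSplitKHorizontalThreadblockSwizzle,
        2,                                             // ${stages}
        8,                                             // ${align_a}
        8,                                             // ${align_b}
        cutlass::arch::OpMultiplyAdd>;                 // ${math_operation}
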
@@ -32,6 +32,8 @@ if __name__ == "__main__":
         f.write("# Generated by dnn/scripts/cutlass_generator/gen_list.py\n\n")
         f.write("cutlass_gen_list = [\n")
         write_op_list(f, "gemm", "simt")
+        write_op_list(f, "gemm", "tensorop1688")
+        write_op_list(f, "gemm", "tensorop884")
         write_op_list(f, "gemv", "simt")
         write_op_list(f, "deconv", "simt")
         write_op_list(f, "conv2d", "simt")
@@ -597,6 +597,131 @@ def GenerateGemv_Simt(args):
     return operations

 #
+def GeneratesGemm_TensorOp_1688(args):
+    layouts = [
+        (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.RowMajor),  # nn
+        (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.RowMajor),     # nt
+        (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor),     # tn
+        (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.RowMajor),        # tt
+    ]
+
+    math_instructions = [
+        MathInstruction( \
+            [16, 8, 8], \
+            DataType.f16, DataType.f16, DataType.f32, \
+            OpcodeClass.TensorOp, \
+            MathOperation.multiply_add),
+        MathInstruction( \
+            [16, 8, 8], \
+            DataType.f16, DataType.f16, DataType.f16, \
+            OpcodeClass.TensorOp, \
+            MathOperation.multiply_add),
+    ]
+
+    min_cc = 75
+    max_cc = 1024
+
+    alignment_constraints = [8, 4, 2,
+                             # 1
+                             ]
+
+    operations = []
+    for math_inst in math_instructions:
+        for layout in layouts:
+            for align in alignment_constraints:
+                tile_descriptions = [
+                    TileDescription([256, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc),
+                    TileDescription([128, 256, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc),
+                    TileDescription([128, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
+                    ## comment some configuration to reduce compilation time and binary size
+                    # TileDescription([ 64, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
+                    # TileDescription([128,  64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
+                    # TileDescription([ 64,  64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
+                ]
+
+                data_type = [
+                    math_inst.element_a,
+                    math_inst.element_b,
+                    math_inst.element_a,
+                    math_inst.element_accumulator,
+                ]
+
+                for tile in tile_descriptions:
+                    operations += GeneratesGemm(tile, \
+                                                data_type, \
+                                                layout[0], \
+                                                layout[1], \
+                                                layout[2], \
+                                                min_cc, \
+                                                align * 16, \
+                                                align * 16, \
+                                                align * 16)
+    return operations
+
+#
+def GeneratesGemm_TensorOp_884(args):
+    layouts = [
+        (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.RowMajor),  # nn
+        (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.RowMajor),     # nt
+        (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor),     # tn
+        (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.RowMajor),        # tt
+    ]
+
+    math_instructions = [
+        MathInstruction( \
+            [8, 8, 4], \
+            DataType.f16, DataType.f16, DataType.f32, \
+            OpcodeClass.TensorOp, \
+            MathOperation.multiply_add),
+        MathInstruction( \
+            [8, 8, 4], \
+            DataType.f16, DataType.f16, DataType.f16, \
+            OpcodeClass.TensorOp, \
+            MathOperation.multiply_add),
+    ]
+
+    min_cc = 70
+    max_cc = 75
+
+    alignment_constraints = [8, 4, 2,
+                             # 1
+                             ]
+
+    operations = []
+    for math_inst in math_instructions:
+        for layout in layouts:
+            for align in alignment_constraints:
+                tile_descriptions = [
+                    TileDescription([256, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc),
+                    TileDescription([128, 256, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc),
+                    TileDescription([128, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
+                    ## comment some configuration to reduce compilation time and binary size
+                    # TileDescription([ 64, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
+                    # TileDescription([128,  64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
+                    # TileDescription([ 64,  64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
+                ]
+
+                data_type = [
+                    math_inst.element_a,
+                    math_inst.element_b,
+                    math_inst.element_a,
+                    math_inst.element_accumulator,
+                ]
+
+                for tile in tile_descriptions:
+                    operations += GeneratesGemm(tile, \
+                                                data_type, \
+                                                layout[0], \
+                                                layout[1], \
+                                                layout[2], \
+                                                min_cc, \
+                                                align * 16, \
+                                                align * 16, \
+                                                align * 16)
+    return operations
+
+#
 def GenerateConv2dOperations(args):
     if args.type == "simt":
         return GenerateConv2d_Simt(args)
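
Note on the two generators above: alignment_constraints is given in elements, while GeneratesGemm takes alignments in bits, hence (presumably) the align * 16 at the call site, since one f16 element is 16 bits wide. The compute-capability bounds also differ deliberately: mma.884 (8x8x4) is the Volta instruction (min_cc=70, max_cc=75), while mma.1688 (16x8x8) requires Turing (min_cc=75) or newer. A trivial C++ check of the unit conversion, for illustration only:

    #include <cuda_fp16.h>  // __half

    static_assert(sizeof(__half) * 8 == 16, "one f16 element == 16 bits");
    // align = 8 elements -> 8 * 16 = 128-bit (16-byte) vectorized access,
    // the widest global-memory access CUTLASS issues.
    static_assert(8 * 16 == 128, "align8 corresponds to 128-bit loads");
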
@@ -613,9 +738,14 @@ def GenerateDeconvOperations(args):
     return GenerateDeconv_Simt(args)

 def GenerateGemmOperations(args):
-    assert args.type == "simt", "operation gemm only support" \
-        "simt. (got:{})".format(args.type)
-    return GenerateGemm_Simt(args)
+    if args.type == "tensorop884":
+        return GeneratesGemm_TensorOp_884(args)
+    elif args.type == "tensorop1688":
+        return GeneratesGemm_TensorOp_1688(args)
+    else:
+        assert args.type == "simt", "operation gemm only support" \
+            "simt. (got:{})".format(args.type)
+        return GenerateGemm_Simt(args)

 def GenerateGemvOperations(args):
     assert args.type == "simt", "operation gemv only support" \
@@ -631,7 +761,7 @@ if __name__ == "__main__":
     parser.add_argument("--operations", type=str, choices=['gemm', 'gemv', 'conv2d', 'deconv'],
                         required=True, help="Specifies the operation to generate (gemm, gemv, conv2d, deconv)")
     parser.add_argument("output", type=str, help="output directory for CUTLASS kernel files")
-    parser.add_argument("--type", type=str, choices=['simt', 'tensorop8816', 'tensorop8832'],
+    parser.add_argument("--type", type=str, choices=['simt', 'tensorop8816', 'tensorop8832', 'tensorop884', 'tensorop1688'],
                         default='simt', help="kernel type of CUTLASS kernel generator")

     gemv_wrapper_path = "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper_batched_gemv_strided.cuinl"
@@ -138,6 +138,296 @@ cutlass_gen_list = [
     "cutlass_simt_sgemm_256x64_8x2_tt_align1.cu",
     "cutlass_simt_sgemm_split_k_parallel_256x64_8x2_tt_align1.cu",
     "all_gemm_simt_operations.cu",
+    "cutlass_tensorop_f16_s1688gemm_f16_256x128_32x2_nn_align8.cu",
+    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_256x128_32x2_nn_align8.cu",
+    "cutlass_tensorop_f16_s1688gemm_f16_128x256_32x2_nn_align8.cu",
+    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x256_32x2_nn_align8.cu",
+    "cutlass_tensorop_f16_s1688gemm_f16_128x128_32x2_nn_align8.cu",
+    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x128_32x2_nn_align8.cu",
+    "cutlass_tensorop_f16_s1688gemm_f16_256x128_32x2_nn_align4.cu",
+    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_256x128_32x2_nn_align4.cu",
+    "cutlass_tensorop_f16_s1688gemm_f16_128x256_32x2_nn_align4.cu",
+    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x256_32x2_nn_align4.cu",
+    "cutlass_tensorop_f16_s1688gemm_f16_128x128_32x2_nn_align4.cu",
+    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x128_32x2_nn_align4.cu",
+    "cutlass_tensorop_f16_s1688gemm_f16_256x128_32x2_nn_align2.cu",
+    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_256x128_32x2_nn_align2.cu",
+    "cutlass_tensorop_f16_s1688gemm_f16_128x256_32x2_nn_align2.cu",
+    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x256_32x2_nn_align2.cu",
+    "cutlass_tensorop_f16_s1688gemm_f16_128x128_32x2_nn_align2.cu",
+    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x128_32x2_nn_align2.cu",
+    "cutlass_tensorop_f16_s1688gemm_f16_256x128_32x2_nt_align8.cu",
+    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_256x128_32x2_nt_align8.cu",
+    "cutlass_tensorop_f16_s1688gemm_f16_128x256_32x2_nt_align8.cu",
+    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x256_32x2_nt_align8.cu",
+    "cutlass_tensorop_f16_s1688gemm_f16_128x128_32x2_nt_align8.cu",
+    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x128_32x2_nt_align8.cu",
+    "cutlass_tensorop_f16_s1688gemm_f16_256x128_32x2_nt_align4.cu",
+    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_256x128_32x2_nt_align4.cu",
+    "cutlass_tensorop_f16_s1688gemm_f16_128x256_32x2_nt_align4.cu",
+    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x256_32x2_nt_align4.cu",
+    "cutlass_tensorop_f16_s1688gemm_f16_128x128_32x2_nt_align4.cu",
+    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x128_32x2_nt_align4.cu",
+    "cutlass_tensorop_f16_s1688gemm_f16_256x128_32x2_nt_align2.cu",
+    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_256x128_32x2_nt_align2.cu",
+    "cutlass_tensorop_f16_s1688gemm_f16_128x256_32x2_nt_align2.cu",
+    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x256_32x2_nt_align2.cu",
+    "cutlass_tensorop_f16_s1688gemm_f16_128x128_32x2_nt_align2.cu",
+    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x128_32x2_nt_align2.cu",
+    "cutlass_tensorop_f16_s1688gemm_f16_256x128_32x2_tn_align8.cu",
+    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_256x128_32x2_tn_align8.cu",
+    "cutlass_tensorop_f16_s1688gemm_f16_128x256_32x2_tn_align8.cu",
+    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x256_32x2_tn_align8.cu",
+    "cutlass_tensorop_f16_s1688gemm_f16_128x128_32x2_tn_align8.cu",
+    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x128_32x2_tn_align8.cu",
+    "cutlass_tensorop_f16_s1688gemm_f16_256x128_32x2_tn_align4.cu",
+    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_256x128_32x2_tn_align4.cu",
+    "cutlass_tensorop_f16_s1688gemm_f16_128x256_32x2_tn_align4.cu",
+    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x256_32x2_tn_align4.cu",
+    "cutlass_tensorop_f16_s1688gemm_f16_128x128_32x2_tn_align4.cu",
+    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x128_32x2_tn_align4.cu",
+    "cutlass_tensorop_f16_s1688gemm_f16_256x128_32x2_tn_align2.cu",
+    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_256x128_32x2_tn_align2.cu",
+    "cutlass_tensorop_f16_s1688gemm_f16_128x256_32x2_tn_align2.cu",
+    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x256_32x2_tn_align2.cu",
+    "cutlass_tensorop_f16_s1688gemm_f16_128x128_32x2_tn_align2.cu",
+    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x128_32x2_tn_align2.cu",
+    "cutlass_tensorop_f16_s1688gemm_f16_256x128_32x2_tt_align8.cu",
+    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_256x128_32x2_tt_align8.cu",
+    "cutlass_tensorop_f16_s1688gemm_f16_128x256_32x2_tt_align8.cu",
+    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x256_32x2_tt_align8.cu",
+    "cutlass_tensorop_f16_s1688gemm_f16_128x128_32x2_tt_align8.cu",
+    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x128_32x2_tt_align8.cu",
+    "cutlass_tensorop_f16_s1688gemm_f16_256x128_32x2_tt_align4.cu",
+    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_256x128_32x2_tt_align4.cu",
+    "cutlass_tensorop_f16_s1688gemm_f16_128x256_32x2_tt_align4.cu",
+    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x256_32x2_tt_align4.cu",
+    "cutlass_tensorop_f16_s1688gemm_f16_128x128_32x2_tt_align4.cu",
+    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x128_32x2_tt_align4.cu",
+    "cutlass_tensorop_f16_s1688gemm_f16_256x128_32x2_tt_align2.cu",
+    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_256x128_32x2_tt_align2.cu",
+    "cutlass_tensorop_f16_s1688gemm_f16_128x256_32x2_tt_align2.cu",
+    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x256_32x2_tt_align2.cu",
+    "cutlass_tensorop_f16_s1688gemm_f16_128x128_32x2_tt_align2.cu",
+    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x128_32x2_tt_align2.cu",
+    "cutlass_tensorop_h1688gemm_256x128_32x2_nn_align8.cu",
+    "cutlass_tensorop_h1688gemm_split_k_parallel_256x128_32x2_nn_align8.cu",
+    "cutlass_tensorop_h1688gemm_128x256_32x2_nn_align8.cu",
+    "cutlass_tensorop_h1688gemm_split_k_parallel_128x256_32x2_nn_align8.cu",
+    "cutlass_tensorop_h1688gemm_128x128_32x2_nn_align8.cu",
+    "cutlass_tensorop_h1688gemm_split_k_parallel_128x128_32x2_nn_align8.cu",
+    "cutlass_tensorop_h1688gemm_256x128_32x2_nn_align4.cu",
+    "cutlass_tensorop_h1688gemm_split_k_parallel_256x128_32x2_nn_align4.cu",
+    "cutlass_tensorop_h1688gemm_128x256_32x2_nn_align4.cu",
+    "cutlass_tensorop_h1688gemm_split_k_parallel_128x256_32x2_nn_align4.cu",
+    "cutlass_tensorop_h1688gemm_128x128_32x2_nn_align4.cu",
+    "cutlass_tensorop_h1688gemm_split_k_parallel_128x128_32x2_nn_align4.cu",
+    "cutlass_tensorop_h1688gemm_256x128_32x2_nn_align2.cu",
+    "cutlass_tensorop_h1688gemm_split_k_parallel_256x128_32x2_nn_align2.cu",
+    "cutlass_tensorop_h1688gemm_128x256_32x2_nn_align2.cu",
+    "cutlass_tensorop_h1688gemm_split_k_parallel_128x256_32x2_nn_align2.cu",
+    "cutlass_tensorop_h1688gemm_128x128_32x2_nn_align2.cu",
+    "cutlass_tensorop_h1688gemm_split_k_parallel_128x128_32x2_nn_align2.cu",
+    "cutlass_tensorop_h1688gemm_256x128_32x2_nt_align8.cu",
+    "cutlass_tensorop_h1688gemm_split_k_parallel_256x128_32x2_nt_align8.cu",
+    "cutlass_tensorop_h1688gemm_128x256_32x2_nt_align8.cu",
+    "cutlass_tensorop_h1688gemm_split_k_parallel_128x256_32x2_nt_align8.cu",
+    "cutlass_tensorop_h1688gemm_128x128_32x2_nt_align8.cu",
+    "cutlass_tensorop_h1688gemm_split_k_parallel_128x128_32x2_nt_align8.cu",
+    "cutlass_tensorop_h1688gemm_256x128_32x2_nt_align4.cu",
+    "cutlass_tensorop_h1688gemm_split_k_parallel_256x128_32x2_nt_align4.cu",
+    "cutlass_tensorop_h1688gemm_128x256_32x2_nt_align4.cu",
+    "cutlass_tensorop_h1688gemm_split_k_parallel_128x256_32x2_nt_align4.cu",
+    "cutlass_tensorop_h1688gemm_128x128_32x2_nt_align4.cu",
+    "cutlass_tensorop_h1688gemm_split_k_parallel_128x128_32x2_nt_align4.cu",
+    "cutlass_tensorop_h1688gemm_256x128_32x2_nt_align2.cu",
+    "cutlass_tensorop_h1688gemm_split_k_parallel_256x128_32x2_nt_align2.cu",
+    "cutlass_tensorop_h1688gemm_128x256_32x2_nt_align2.cu",
+    "cutlass_tensorop_h1688gemm_split_k_parallel_128x256_32x2_nt_align2.cu",
+    "cutlass_tensorop_h1688gemm_128x128_32x2_nt_align2.cu",
+    "cutlass_tensorop_h1688gemm_split_k_parallel_128x128_32x2_nt_align2.cu",
+    "cutlass_tensorop_h1688gemm_256x128_32x2_tn_align8.cu",
+    "cutlass_tensorop_h1688gemm_split_k_parallel_256x128_32x2_tn_align8.cu",
+    "cutlass_tensorop_h1688gemm_128x256_32x2_tn_align8.cu",
+    "cutlass_tensorop_h1688gemm_split_k_parallel_128x256_32x2_tn_align8.cu",
+    "cutlass_tensorop_h1688gemm_128x128_32x2_tn_align8.cu",
+    "cutlass_tensorop_h1688gemm_split_k_parallel_128x128_32x2_tn_align8.cu",
+    "cutlass_tensorop_h1688gemm_256x128_32x2_tn_align4.cu",
+    "cutlass_tensorop_h1688gemm_split_k_parallel_256x128_32x2_tn_align4.cu",
+    "cutlass_tensorop_h1688gemm_128x256_32x2_tn_align4.cu",
+    "cutlass_tensorop_h1688gemm_split_k_parallel_128x256_32x2_tn_align4.cu",
+    "cutlass_tensorop_h1688gemm_128x128_32x2_tn_align4.cu",
+    "cutlass_tensorop_h1688gemm_split_k_parallel_128x128_32x2_tn_align4.cu",
+    "cutlass_tensorop_h1688gemm_256x128_32x2_tn_align2.cu",
+    "cutlass_tensorop_h1688gemm_split_k_parallel_256x128_32x2_tn_align2.cu",
+    "cutlass_tensorop_h1688gemm_128x256_32x2_tn_align2.cu",
+    "cutlass_tensorop_h1688gemm_split_k_parallel_128x256_32x2_tn_align2.cu",
+    "cutlass_tensorop_h1688gemm_128x128_32x2_tn_align2.cu",
+    "cutlass_tensorop_h1688gemm_split_k_parallel_128x128_32x2_tn_align2.cu",
+    "cutlass_tensorop_h1688gemm_256x128_32x2_tt_align8.cu",
+    "cutlass_tensorop_h1688gemm_split_k_parallel_256x128_32x2_tt_align8.cu",
+    "cutlass_tensorop_h1688gemm_128x256_32x2_tt_align8.cu",
+    "cutlass_tensorop_h1688gemm_split_k_parallel_128x256_32x2_tt_align8.cu",
+    "cutlass_tensorop_h1688gemm_128x128_32x2_tt_align8.cu",
+    "cutlass_tensorop_h1688gemm_split_k_parallel_128x128_32x2_tt_align8.cu",
+    "cutlass_tensorop_h1688gemm_256x128_32x2_tt_align4.cu",
+    "cutlass_tensorop_h1688gemm_split_k_parallel_256x128_32x2_tt_align4.cu",
+    "cutlass_tensorop_h1688gemm_128x256_32x2_tt_align4.cu",
+    "cutlass_tensorop_h1688gemm_split_k_parallel_128x256_32x2_tt_align4.cu",
+    "cutlass_tensorop_h1688gemm_128x128_32x2_tt_align4.cu",
+    "cutlass_tensorop_h1688gemm_split_k_parallel_128x128_32x2_tt_align4.cu",
+    "cutlass_tensorop_h1688gemm_256x128_32x2_tt_align2.cu",
+    "cutlass_tensorop_h1688gemm_split_k_parallel_256x128_32x2_tt_align2.cu",
+    "cutlass_tensorop_h1688gemm_128x256_32x2_tt_align2.cu",
+    "cutlass_tensorop_h1688gemm_split_k_parallel_128x256_32x2_tt_align2.cu",
+    "cutlass_tensorop_h1688gemm_128x128_32x2_tt_align2.cu",
+    "cutlass_tensorop_h1688gemm_split_k_parallel_128x128_32x2_tt_align2.cu",
+    "all_gemm_tensorop1688_operations.cu",
+    "cutlass_tensorop_f16_s884gemm_f16_256x128_32x2_nn_align8.cu",
+    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_256x128_32x2_nn_align8.cu",
+    "cutlass_tensorop_f16_s884gemm_f16_128x256_32x2_nn_align8.cu",
+    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x256_32x2_nn_align8.cu",
+    "cutlass_tensorop_f16_s884gemm_f16_128x128_32x2_nn_align8.cu",
+    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x128_32x2_nn_align8.cu",
+    "cutlass_tensorop_f16_s884gemm_f16_256x128_32x2_nn_align4.cu",
+    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_256x128_32x2_nn_align4.cu",
+    "cutlass_tensorop_f16_s884gemm_f16_128x256_32x2_nn_align4.cu",
+    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x256_32x2_nn_align4.cu",
+    "cutlass_tensorop_f16_s884gemm_f16_128x128_32x2_nn_align4.cu",
+    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x128_32x2_nn_align4.cu",
+    "cutlass_tensorop_f16_s884gemm_f16_256x128_32x2_nn_align2.cu",
+    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_256x128_32x2_nn_align2.cu",
+    "cutlass_tensorop_f16_s884gemm_f16_128x256_32x2_nn_align2.cu",
+    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x256_32x2_nn_align2.cu",
+    "cutlass_tensorop_f16_s884gemm_f16_128x128_32x2_nn_align2.cu",
+    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x128_32x2_nn_align2.cu",
+    "cutlass_tensorop_f16_s884gemm_f16_256x128_32x2_nt_align8.cu",
+    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_256x128_32x2_nt_align8.cu",
+    "cutlass_tensorop_f16_s884gemm_f16_128x256_32x2_nt_align8.cu",
+    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x256_32x2_nt_align8.cu",
+    "cutlass_tensorop_f16_s884gemm_f16_128x128_32x2_nt_align8.cu",
+    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x128_32x2_nt_align8.cu",
+    "cutlass_tensorop_f16_s884gemm_f16_256x128_32x2_nt_align4.cu",
+    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_256x128_32x2_nt_align4.cu",
+    "cutlass_tensorop_f16_s884gemm_f16_128x256_32x2_nt_align4.cu",
+    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x256_32x2_nt_align4.cu",
+    "cutlass_tensorop_f16_s884gemm_f16_128x128_32x2_nt_align4.cu",
+    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x128_32x2_nt_align4.cu",
+    "cutlass_tensorop_f16_s884gemm_f16_256x128_32x2_nt_align2.cu",
+    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_256x128_32x2_nt_align2.cu",
+    "cutlass_tensorop_f16_s884gemm_f16_128x256_32x2_nt_align2.cu",
+    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x256_32x2_nt_align2.cu",
+    "cutlass_tensorop_f16_s884gemm_f16_128x128_32x2_nt_align2.cu",
+    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x128_32x2_nt_align2.cu",
+    "cutlass_tensorop_f16_s884gemm_f16_256x128_32x2_tn_align8.cu",
+    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_256x128_32x2_tn_align8.cu",
+    "cutlass_tensorop_f16_s884gemm_f16_128x256_32x2_tn_align8.cu",
+    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x256_32x2_tn_align8.cu",
+    "cutlass_tensorop_f16_s884gemm_f16_128x128_32x2_tn_align8.cu",
+    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x128_32x2_tn_align8.cu",
+    "cutlass_tensorop_f16_s884gemm_f16_256x128_32x2_tn_align4.cu",
+    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_256x128_32x2_tn_align4.cu",
+    "cutlass_tensorop_f16_s884gemm_f16_128x256_32x2_tn_align4.cu",
+    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x256_32x2_tn_align4.cu",
+    "cutlass_tensorop_f16_s884gemm_f16_128x128_32x2_tn_align4.cu",
+    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x128_32x2_tn_align4.cu",
+    "cutlass_tensorop_f16_s884gemm_f16_256x128_32x2_tn_align2.cu",
+    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_256x128_32x2_tn_align2.cu",
+    "cutlass_tensorop_f16_s884gemm_f16_128x256_32x2_tn_align2.cu",
+    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x256_32x2_tn_align2.cu",
+    "cutlass_tensorop_f16_s884gemm_f16_128x128_32x2_tn_align2.cu",
+    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x128_32x2_tn_align2.cu",
+    "cutlass_tensorop_f16_s884gemm_f16_256x128_32x2_tt_align8.cu",
+    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_256x128_32x2_tt_align8.cu",
+    "cutlass_tensorop_f16_s884gemm_f16_128x256_32x2_tt_align8.cu",
+    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x256_32x2_tt_align8.cu",
+    "cutlass_tensorop_f16_s884gemm_f16_128x128_32x2_tt_align8.cu",
+    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x128_32x2_tt_align8.cu",
+    "cutlass_tensorop_f16_s884gemm_f16_256x128_32x2_tt_align4.cu",
+    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_256x128_32x2_tt_align4.cu",
+    "cutlass_tensorop_f16_s884gemm_f16_128x256_32x2_tt_align4.cu",
+    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x256_32x2_tt_align4.cu",
+    "cutlass_tensorop_f16_s884gemm_f16_128x128_32x2_tt_align4.cu",
+    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x128_32x2_tt_align4.cu",
+    "cutlass_tensorop_f16_s884gemm_f16_256x128_32x2_tt_align2.cu",
+    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_256x128_32x2_tt_align2.cu",
+    "cutlass_tensorop_f16_s884gemm_f16_128x256_32x2_tt_align2.cu",
+    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x256_32x2_tt_align2.cu",
+    "cutlass_tensorop_f16_s884gemm_f16_128x128_32x2_tt_align2.cu",
+    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x128_32x2_tt_align2.cu",
+    "cutlass_tensorop_h884gemm_256x128_32x2_nn_align8.cu",
+    "cutlass_tensorop_h884gemm_split_k_parallel_256x128_32x2_nn_align8.cu",
+    "cutlass_tensorop_h884gemm_128x256_32x2_nn_align8.cu",
+    "cutlass_tensorop_h884gemm_split_k_parallel_128x256_32x2_nn_align8.cu",
+    "cutlass_tensorop_h884gemm_128x128_32x2_nn_align8.cu",
+    "cutlass_tensorop_h884gemm_split_k_parallel_128x128_32x2_nn_align8.cu",
+    "cutlass_tensorop_h884gemm_256x128_32x2_nn_align4.cu",
+    "cutlass_tensorop_h884gemm_split_k_parallel_256x128_32x2_nn_align4.cu",
+    "cutlass_tensorop_h884gemm_128x256_32x2_nn_align4.cu",
+    "cutlass_tensorop_h884gemm_split_k_parallel_128x256_32x2_nn_align4.cu",
+    "cutlass_tensorop_h884gemm_128x128_32x2_nn_align4.cu",
+    "cutlass_tensorop_h884gemm_split_k_parallel_128x128_32x2_nn_align4.cu",
+    "cutlass_tensorop_h884gemm_256x128_32x2_nn_align2.cu",
+    "cutlass_tensorop_h884gemm_split_k_parallel_256x128_32x2_nn_align2.cu",
+    "cutlass_tensorop_h884gemm_128x256_32x2_nn_align2.cu",
+    "cutlass_tensorop_h884gemm_split_k_parallel_128x256_32x2_nn_align2.cu",
+    "cutlass_tensorop_h884gemm_128x128_32x2_nn_align2.cu",
+    "cutlass_tensorop_h884gemm_split_k_parallel_128x128_32x2_nn_align2.cu",
+    "cutlass_tensorop_h884gemm_256x128_32x2_nt_align8.cu",
+    "cutlass_tensorop_h884gemm_split_k_parallel_256x128_32x2_nt_align8.cu",
+    "cutlass_tensorop_h884gemm_128x256_32x2_nt_align8.cu",
+    "cutlass_tensorop_h884gemm_split_k_parallel_128x256_32x2_nt_align8.cu",
+    "cutlass_tensorop_h884gemm_128x128_32x2_nt_align8.cu",
+    "cutlass_tensorop_h884gemm_split_k_parallel_128x128_32x2_nt_align8.cu",
+    "cutlass_tensorop_h884gemm_256x128_32x2_nt_align4.cu",
+    "cutlass_tensorop_h884gemm_split_k_parallel_256x128_32x2_nt_align4.cu",
+    "cutlass_tensorop_h884gemm_128x256_32x2_nt_align4.cu",
+    "cutlass_tensorop_h884gemm_split_k_parallel_128x256_32x2_nt_align4.cu",
+    "cutlass_tensorop_h884gemm_128x128_32x2_nt_align4.cu",
+    "cutlass_tensorop_h884gemm_split_k_parallel_128x128_32x2_nt_align4.cu",
+    "cutlass_tensorop_h884gemm_256x128_32x2_nt_align2.cu",
+    "cutlass_tensorop_h884gemm_split_k_parallel_256x128_32x2_nt_align2.cu",
+    "cutlass_tensorop_h884gemm_128x256_32x2_nt_align2.cu",
+    "cutlass_tensorop_h884gemm_split_k_parallel_128x256_32x2_nt_align2.cu",
+    "cutlass_tensorop_h884gemm_128x128_32x2_nt_align2.cu",
+    "cutlass_tensorop_h884gemm_split_k_parallel_128x128_32x2_nt_align2.cu",
+    "cutlass_tensorop_h884gemm_256x128_32x2_tn_align8.cu",
+    "cutlass_tensorop_h884gemm_split_k_parallel_256x128_32x2_tn_align8.cu",
+    "cutlass_tensorop_h884gemm_128x256_32x2_tn_align8.cu",
+    "cutlass_tensorop_h884gemm_split_k_parallel_128x256_32x2_tn_align8.cu",
+    "cutlass_tensorop_h884gemm_128x128_32x2_tn_align8.cu",
+    "cutlass_tensorop_h884gemm_split_k_parallel_128x128_32x2_tn_align8.cu",
+    "cutlass_tensorop_h884gemm_256x128_32x2_tn_align4.cu",
+    "cutlass_tensorop_h884gemm_split_k_parallel_256x128_32x2_tn_align4.cu",
+    "cutlass_tensorop_h884gemm_128x256_32x2_tn_align4.cu",
+    "cutlass_tensorop_h884gemm_split_k_parallel_128x256_32x2_tn_align4.cu",
+    "cutlass_tensorop_h884gemm_128x128_32x2_tn_align4.cu",
+    "cutlass_tensorop_h884gemm_split_k_parallel_128x128_32x2_tn_align4.cu",
+    "cutlass_tensorop_h884gemm_256x128_32x2_tn_align2.cu",
+    "cutlass_tensorop_h884gemm_split_k_parallel_256x128_32x2_tn_align2.cu",
+    "cutlass_tensorop_h884gemm_128x256_32x2_tn_align2.cu",
+    "cutlass_tensorop_h884gemm_split_k_parallel_128x256_32x2_tn_align2.cu",
+    "cutlass_tensorop_h884gemm_128x128_32x2_tn_align2.cu",
+    "cutlass_tensorop_h884gemm_split_k_parallel_128x128_32x2_tn_align2.cu",
+    "cutlass_tensorop_h884gemm_256x128_32x2_tt_align8.cu",
+    "cutlass_tensorop_h884gemm_split_k_parallel_256x128_32x2_tt_align8.cu",
+    "cutlass_tensorop_h884gemm_128x256_32x2_tt_align8.cu",
+    "cutlass_tensorop_h884gemm_split_k_parallel_128x256_32x2_tt_align8.cu",
+    "cutlass_tensorop_h884gemm_128x128_32x2_tt_align8.cu",
+    "cutlass_tensorop_h884gemm_split_k_parallel_128x128_32x2_tt_align8.cu",
+    "cutlass_tensorop_h884gemm_256x128_32x2_tt_align4.cu",
+    "cutlass_tensorop_h884gemm_split_k_parallel_256x128_32x2_tt_align4.cu",
+    "cutlass_tensorop_h884gemm_128x256_32x2_tt_align4.cu",
+    "cutlass_tensorop_h884gemm_split_k_parallel_128x256_32x2_tt_align4.cu",
+    "cutlass_tensorop_h884gemm_128x128_32x2_tt_align4.cu",
+    "cutlass_tensorop_h884gemm_split_k_parallel_128x128_32x2_tt_align4.cu",
+    "cutlass_tensorop_h884gemm_256x128_32x2_tt_align2.cu",
+    "cutlass_tensorop_h884gemm_split_k_parallel_256x128_32x2_tt_align2.cu",
+    "cutlass_tensorop_h884gemm_128x256_32x2_tt_align2.cu",
+    "cutlass_tensorop_h884gemm_split_k_parallel_128x256_32x2_tt_align2.cu",
+    "cutlass_tensorop_h884gemm_128x128_32x2_tt_align2.cu",
+    "cutlass_tensorop_h884gemm_split_k_parallel_128x128_32x2_tt_align2.cu",
+    "all_gemm_tensorop884_operations.cu",
     "cutlass_simt_sgemv_batched_strided_1x128_32_tt_align4x4.cu",
     "cutlass_simt_sgemv_batched_strided_1x128_16_tt_align4x2.cu",
     "cutlass_simt_sgemv_batched_strided_1x128_8_tt_align4x1.cu",
@@ -646,4 +936,4 @@ cutlass_gen_list = [
     "cutlass_tensorop_u4_i8832fprop_roc_relu_u4_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu",
     "cutlass_tensorop_u4_i8832fprop_1x1_roc_relu_u4_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu",
     "all_conv2d_tensorop8832_operations.cu",
-]
+]
@@ -151,6 +151,8 @@ if(MGE_WITH_CUDA)
     set(${gen_files} "${${gen_files}}" PARENT_SCOPE)
   endfunction()
   gen_cutlass_kimpl(gemm simt CUTLASS_SOURCES)
+  gen_cutlass_kimpl(gemm tensorop884 CUTLASS_SOURCES)
+  gen_cutlass_kimpl(gemm tensorop1688 CUTLASS_SOURCES)
   gen_cutlass_kimpl(gemv simt CUTLASS_SOURCES)
   gen_cutlass_kimpl(deconv simt CUTLASS_SOURCES)
   gen_cutlass_kimpl(conv2d simt CUTLASS_SOURCES)
@@ -49,6 +49,8 @@ namespace library {
         (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
 void initialize_all_gemm_simt_operations(Manifest& manifest);
+void initialize_all_gemm_tensorop884_operations(Manifest& manifest);
+void initialize_all_gemm_tensorop1688_operations(Manifest& manifest);
 void initialize_all_conv2d_simt_operations(Manifest& manifest);
 void initialize_all_conv2d_tensorop8816_operations(Manifest& manifest);
 void initialize_all_conv2d_tensorop8832_operations(Manifest& manifest);
@@ -56,6 +58,8 @@ void initialize_all_deconv_simt_operations(Manifest& manifest);
 void initialize_all(Manifest& manifest) {
     initialize_all_gemm_simt_operations(manifest);
+    initialize_all_gemm_tensorop884_operations(manifest);
+    initialize_all_gemm_tensorop1688_operations(manifest);
    initialize_all_conv2d_simt_operations(manifest);
    initialize_all_conv2d_tensorop8816_operations(manifest);
    initialize_all_conv2d_tensorop8832_operations(manifest);
@@ -55,6 +55,8 @@ GemmKey get_gemm_key_from_desc(const GemmDescription& desc) {
     key.layout_B = desc.B.layout;
     key.element_C = desc.C.element;
     key.layout_C = desc.C.layout;
+    key.element_accumulator =
+            desc.tile_description.math_instruction.element_accumulator;

     key.threadblock_shape_m = desc.tile_description.threadblock_shape.m();
     key.threadblock_shape_n = desc.tile_description.threadblock_shape.n();
@@ -75,6 +77,8 @@ GemmKey get_gemm_key_from_desc(const GemmDescription& desc) {
             desc.tile_description.math_instruction.instruction_shape.k();

     key.stages = desc.stages;
+    key.alignment_A = desc.A.alignment;
+    key.alignment_B = desc.B.alignment;
     key.split_k_mode = desc.split_k_mode;

     return key;
@@ -77,6 +77,7 @@ struct GemmKey {
     LayoutTypeID layout_B;
     NumericTypeID element_C;
     LayoutTypeID layout_C;
+    NumericTypeID element_accumulator;

     int threadblock_shape_m;
     int threadblock_shape_n;
@@ -91,12 +92,15 @@ struct GemmKey {
     int instruction_shape_k;

     int stages;
+    int alignment_A;
+    int alignment_B;
     SplitKMode split_k_mode;

     inline bool operator==(GemmKey const& rhs) const {
         return (element_A == rhs.element_A) && (layout_A == rhs.layout_A) &&
                (element_B == rhs.element_B) && (layout_B == rhs.layout_B) &&
                (element_C == rhs.element_C) && (layout_C == rhs.layout_C) &&
+               (element_accumulator == rhs.element_accumulator) &&
                (threadblock_shape_m == rhs.threadblock_shape_m) &&
                (threadblock_shape_n == rhs.threadblock_shape_n) &&
                (threadblock_shape_k == rhs.threadblock_shape_k) &&
@@ -106,7 +110,9 @@ struct GemmKey {
                (instruction_shape_m == rhs.instruction_shape_m) &&
                (instruction_shape_n == rhs.instruction_shape_n) &&
                (instruction_shape_k == rhs.instruction_shape_k) &&
-               (stages == rhs.stages) && (split_k_mode == rhs.split_k_mode);
+               (stages == rhs.stages) && (alignment_A == rhs.alignment_A) &&
+               (alignment_B == rhs.alignment_B) &&
+               (split_k_mode == rhs.split_k_mode);
     }

     inline bool operator!=(GemmKey const& rhs) const { return !(*this == rhs); }
@@ -130,10 +136,13 @@ struct GemmKey {
                "\n layout_B: " + to_string(layout_B) +
                "\n element_C: " + to_string(element_C) +
                "\n layout_C: " + to_string(layout_C) +
+               "\n element_accumulator: " + to_string(element_accumulator) +
                "\n threadblock_shape: " + threadblock_shape_str +
                "\n warp_shape: " + warp_shape_str +
                "\n instruction_shape: " + instruction_shape_str +
                "\n stages: " + std::to_string(stages) +
+               "\n alignment_A: " + std::to_string(alignment_A) +
+               "\n alignment_B: " + std::to_string(alignment_B) +
                "\n split_k_mode: " + to_string(split_k_mode) + "\n}";
     }
 };
@@ -147,6 +156,8 @@ struct GemmKeyHasher {
                 .update(&key.layout_B, sizeof(key.layout_B))
                 .update(&key.element_C, sizeof(key.element_C))
                 .update(&key.layout_C, sizeof(key.layout_C))
+                .update(&key.element_accumulator,
+                        sizeof(key.element_accumulator))
                 .update(&key.threadblock_shape_m,
                         sizeof(key.threadblock_shape_m))
                 .update(&key.threadblock_shape_n,
@@ -157,6 +168,8 @@ struct GemmKeyHasher {
                 .update(&key.warp_shape_n, sizeof(key.warp_shape_n))
                 .update(&key.warp_shape_k, sizeof(key.warp_shape_k))
                 .update(&key.stages, sizeof(key.stages))
+                .update(&key.alignment_A, sizeof(key.alignment_A))
+                .update(&key.alignment_B, sizeof(key.alignment_B))
                 .update(&key.split_k_mode, sizeof(key.split_k_mode))
                 .digest();
     }
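
Note: GemmKey now has four places that must stay in sync whenever a field is added — the member list, operator==, to_string, and GemmKeyHasher. What the two new alignment fields buy is shown in this sketch (base_key is a hypothetical, fully populated key as produced by get_gemm_key_from_desc, inside any function where one is in scope):

    // align8 and align2 builds of an otherwise identical kernel used to
    // produce colliding keys; they are now distinct table entries.
    GemmKey k8 = base_key, k2 = base_key;
    k8.alignment_A = k8.alignment_B = 8;
    k2.alignment_A = k2.alignment_B = 2;
    assert(k8 != k2);  // distinct keys -> both kernels can be registered
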
@@ -43,6 +43,12 @@ MatrixMulForwardImpl::AlgoPack::AlgoPack() {
     for (auto&& algo : simt_float32_gemv_batched_strided) {
         all_algos.push_back(&algo);
     }
+    for (auto&& algo : tensorop_float16) {
+        all_algos.push_back(&algo);
+    }
+    for (auto&& algo : tensorop_float16_split_k) {
+        all_algos.push_back(&algo);
+    }
 #endif
     all_algos.push_back(&naive);
@@ -53,7 +59,7 @@ MatrixMulForwardImpl::AlgoPack::AlgoPack() {
 #if CUDA_VERSION >= 9020
 void MatrixMulForwardImpl::AlgoPack::fill_cutlass_algos() {
-    using AlgoParam = AlgoFloat32SIMT::AlgoParam;
+    using AlgoParam = AlgoCutlassMatrixMulBase::AlgoParam;
     simt_float32.emplace_back(AlgoParam{64, 256, 8, 32, 64, 8});
     simt_float32.emplace_back(AlgoParam{256, 64, 8, 64, 32, 8});
     simt_float32.emplace_back(AlgoParam{32, 256, 8, 16, 64, 8});
@@ -91,6 +97,19 @@ void MatrixMulForwardImpl::AlgoPack::fill_cutlass_algos() {
     simt_float32_gemv_batched_strided.emplace_back(128);
     simt_float32_gemv_batched_strided.emplace_back(64);
     simt_float32_gemv_batched_strided.emplace_back(32);
+#define FOREACH_CUTLASS_MATMUL_F16_SHAPES(cb)  \
+    cb(256, 128, 32, 64, 64, 32, 8, 8, 4);     \
+    cb(128, 256, 32, 64, 64, 32, 8, 8, 4);     \
+    cb(128, 128, 32, 64, 64, 32, 8, 8, 4);     \
+    cb(256, 128, 32, 64, 64, 32, 16, 8, 8);    \
+    cb(128, 256, 32, 64, 64, 32, 16, 8, 8);    \
+    cb(128, 128, 32, 64, 64, 32, 16, 8, 8);
+#define cb(...)                                                \
+    tensorop_float16.emplace_back(AlgoParam{__VA_ARGS__});     \
+    tensorop_float16_split_k.emplace_back(AlgoParam{__VA_ARGS__});
+    FOREACH_CUTLASS_MATMUL_F16_SHAPES(cb)
+#undef cb
+#undef FOREACH_CUTLASS_MATMUL_F16_SHAPES
 }
 #endif
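
Note: spelled out, the first cb invocation in the macro above expands to the two registrations below. Each of the six shape tuples therefore produces one plain and one split-k algorithm, and the last three integers of every tuple are the mma instruction shape (8x8x4 for Volta, 16x8x8 for Turing):

    tensorop_float16.emplace_back(AlgoParam{256, 128, 32, 64, 64, 32, 8, 8, 4});
    tensorop_float16_split_k.emplace_back(
            AlgoParam{256, 128, 32, 64, 64, 32, 8, 8, 4});
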
@@ -41,11 +41,13 @@ public:
         CUDA_WMMA_UINT4X4X32,
         CUDA_CUBLASLT,
         CUDA_NAIVE,
         CUDA_BFLOAT16,
 #if CUDA_VERSION >= 9020
         CUDA_FLOAT32_SIMT,
         CUDA_FLOAT32_SIMT_SPLIT_K,
         CUDA_FLOAT32_SIMT_GEMV_BATCHED_STRIDED,
+        CUDA_FLOAT16_TENSOR_OP,
+        CUDA_FLOAT16_TENSOR_OP_SPLIT_K,
 #endif
     };
     using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>;
@@ -188,65 +190,83 @@ private:
 #endif

 #if CUDA_VERSION >= 9020
-class MatrixMulForwardImpl::AlgoFloat32SIMT final : public AlgoBase {
+class MatrixMulForwardImpl::AlgoCutlassMatrixMulBase : public AlgoBase {
 public:
     struct AlgoParam {
         int threadblock_m, threadblock_n, threadblock_k;
         int warp_m, warp_n, warp_k;
-        std::string to_string() {
-            return ssprintf("%dX%dX%d_%dX%dX%d", threadblock_m, threadblock_n,
-                            threadblock_k, warp_m, warp_n, warp_k);
-        }
+        int instruction_m, instruction_n, instruction_k;
+        AlgoParam(int threadblock_m_, int threadblock_n_, int threadblock_k_,
+                  int warp_m_, int warp_n_, int warp_k_, int instruction_m_ = 1,
+                  int instruction_n_ = 1, int instruction_k_ = 1)
+                : threadblock_m{threadblock_m_},
+                  threadblock_n{threadblock_n_},
+                  threadblock_k{threadblock_k_},
+                  warp_m{warp_m_},
+                  warp_n{warp_n_},
+                  warp_k{warp_k_},
+                  instruction_m{instruction_m_},
+                  instruction_n{instruction_n_},
+                  instruction_k{instruction_k_} {}
+        std::string to_string() const;
     };
+    AlgoCutlassMatrixMulBase(AlgoParam algo_param) : m_algo_param{algo_param} {}
+    void exec(const ExecArgs& args) const override;
+    std::string param() const override {
+        std::string ret;
+        serialize_write_pod(m_algo_param, ret);
+        return ret;
+    }
+
+protected:
+    virtual int min_alignment_requirement() const = 0;
+    virtual void do_exec(const ExecArgs& args) const = 0;
+    std::pair<bool, TensorLayoutArray> construct_aligned_layouts(
+            const SizeArgs& args) const;
+    int max_alignment(const SizeArgs& args) const;
+    AlgoParam m_algo_param;
+};
+
+class MatrixMulForwardImpl::AlgoFloat32SIMT final
+        : public AlgoCutlassMatrixMulBase {
+public:
     AlgoFloat32SIMT(AlgoParam algo_param)
-            : m_algo_param{algo_param},
+            : AlgoCutlassMatrixMulBase{algo_param},
               m_name{ssprintf("CUTLASS_FLOAT32_SIMT_%s",
                               m_algo_param.to_string().c_str())} {}
     bool is_available(const SizeArgs& args) const override;
     size_t get_workspace_in_bytes(const SizeArgs& args) const override;
     const char* name() const override { return m_name.c_str(); }
-    void exec(const ExecArgs& args) const override;
     AlgoAttribute attribute() const override {
         return AlgoAttribute::REPRODUCIBLE;
     }
     MEGDNN_DECL_ALGO_TYPE(CUDA_FLOAT32_SIMT)
-    std::string param() const override {
-        std::string ret;
-        serialize_write_pod(m_algo_param, ret);
-        return ret;
-    }

 private:
-    AlgoParam m_algo_param;
+    void do_exec(const ExecArgs& args) const override;
+    int min_alignment_requirement() const override { return 1; }
     std::string m_name;
 };

-class MatrixMulForwardImpl::AlgoFloat32SIMTSplitK final : public AlgoBase {
+class MatrixMulForwardImpl::AlgoFloat32SIMTSplitK final
+        : public AlgoCutlassMatrixMulBase {
 public:
-    using AlgoParam = MatrixMulForwardImpl::AlgoFloat32SIMT::AlgoParam;
     AlgoFloat32SIMTSplitK(AlgoParam algo_param)
-            : m_algo_param{algo_param},
+            : AlgoCutlassMatrixMulBase{algo_param},
               m_name{ssprintf("CUTLASS_FLOAT32_SIMT_SPLIT_K_%s",
                               m_algo_param.to_string().c_str())} {}
     bool is_available(const SizeArgs& args) const override;
     size_t get_workspace_in_bytes(const SizeArgs& args) const override;
     const char* name() const override { return m_name.c_str(); }
-    void exec(const ExecArgs& args) const override;
     AlgoAttribute attribute() const override {
         return AlgoAttribute::REPRODUCIBLE |
                AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
     }
     MEGDNN_DECL_ALGO_TYPE(CUDA_FLOAT32_SIMT_SPLIT_K)
-    std::string param() const override {
-        std::string ret;
-        serialize_write_pod(m_algo_param, ret);
-        return ret;
-    }

 private:
-    AlgoParam m_algo_param;
+    void do_exec(const ExecArgs& args) const override;
+    int min_alignment_requirement() const override { return 1; }
     std::string m_name;
 };
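
Note: the nine-argument AlgoParam constructor keeps the existing six-argument SIMT call sites compiling unchanged, since the instruction shape defaults to the 1x1x1 placeholder. Both uses side by side (values as registered in fill_cutlass_algos above):

    using AlgoParam = MatrixMulForwardImpl::AlgoCutlassMatrixMulBase::AlgoParam;
    AlgoParam simt{64, 256, 8, 32, 64, 8};                // instruction = 1x1x1
    AlgoParam mma884{256, 128, 32, 64, 64, 32, 8, 8, 4};  // explicit mma.884
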
@@ -276,6 +296,56 @@ private:
     int m_threadblock_n;
     std::string m_name;
 };
+
+class MatrixMulForwardImpl::AlgoFloat16TensorOp final
+        : public AlgoCutlassMatrixMulBase {
+public:
+    AlgoFloat16TensorOp(AlgoParam algo_param)
+            : AlgoCutlassMatrixMulBase{algo_param},
+              m_name{ssprintf("CUTLASS_FLOAT16_TENSOR_OP_h%d%d%d_%s",
+                              m_algo_param.instruction_m,
+                              m_algo_param.instruction_n,
+                              m_algo_param.instruction_k,
+                              m_algo_param.to_string().c_str())} {}
+    bool is_available(const SizeArgs& args) const override;
+    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
+    const char* name() const override { return m_name.c_str(); }
+    AlgoAttribute attribute() const override {
+        return AlgoAttribute::REPRODUCIBLE;
+    }
+    MEGDNN_DECL_ALGO_TYPE(CUDA_FLOAT16_TENSOR_OP)
+
+private:
+    void do_exec(const ExecArgs& args) const override;
+    int min_alignment_requirement() const override { return 2; }
+    std::string m_name;
+};
+
+class MatrixMulForwardImpl::AlgoFloat16TensorOpSplitK final
+        : public AlgoCutlassMatrixMulBase {
+public:
+    AlgoFloat16TensorOpSplitK(AlgoParam algo_param)
+            : AlgoCutlassMatrixMulBase{algo_param},
+              m_name{ssprintf("CUTLASS_FLOAT16_TENSOR_OP_SPLIT_K_h%d%d%d_%s",
+                              m_algo_param.instruction_m,
+                              m_algo_param.instruction_n,
+                              m_algo_param.instruction_k,
+                              m_algo_param.to_string().c_str())} {}
+    bool is_available(const SizeArgs& args) const override;
+    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
+    const char* name() const override { return m_name.c_str(); }
+    AlgoAttribute attribute() const override {
+        return AlgoAttribute::REPRODUCIBLE |
+               AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
+    }
+    MEGDNN_DECL_ALGO_TYPE(CUDA_FLOAT16_TENSOR_OP_SPLIT_K)
+
+private:
+    void do_exec(const ExecArgs& args) const override;
+    int min_alignment_requirement() const override { return 2; }
+    std::string m_name;
+};
 #endif

 class MatrixMulForwardImpl::AlgoPack : NonCopyableObj {
@@ -300,6 +370,8 @@ public:
     std::vector<AlgoFloat32SIMTSplitK> simt_float32_split_k;
     std::vector<AlgoFloat32SIMTGemvBatchedStrided>
             simt_float32_gemv_batched_strided;
+    std::vector<AlgoFloat16TensorOp> tensorop_float16;
+    std::vector<AlgoFloat16TensorOpSplitK> tensorop_float16_split_k;
 #endif

     std::vector<AlgoBase*> all_algos;
@@ -0,0 +1,154 @@ | |||||
/** | |||||
* \file dnn/src/cuda/matrix_mul/cutlass_float16_tensorop.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#include "src/cuda/cutlass/singleton.h" | |||||
#include "src/cuda/handle.h" | |||||
#include "src/cuda/matrix_mul/algos.h" | |||||
#include "src/cuda/utils.h" | |||||
#if CUDA_VERSION >= 9020 | |||||
using namespace megdnn; | |||||
using namespace cuda; | |||||
bool MatrixMulForwardImpl::AlgoFloat16TensorOp::is_available( | |||||
const SizeArgs& args) const { | |||||
bool available = | |||||
args.opr->param().format == param::MatrixMul::Format::DEFAULT && | |||||
args.layout_b.dtype == dtype::Float16() && | |||||
args.layout_c.dtype == dtype::Float16(); | |||||
int n = args.layout_c.shape[1]; | |||||
auto&& device_prop = cuda::current_device_prop(); | |||||
int y_grid_limit = device_prop.maxGridSize[1]; | |||||
// limit y grid | |||||
available &= ((n + m_algo_param.threadblock_n - 1) / | |||||
m_algo_param.threadblock_n <= | |||||
y_grid_limit); | |||||
if (m_algo_param.instruction_m == 8 && m_algo_param.instruction_n == 8 && | |||||
m_algo_param.instruction_k == 4) { | |||||
available &= is_compute_capability_required(7, 0); | |||||
} else { | |||||
megdnn_assert(m_algo_param.instruction_m == 16 && | |||||
m_algo_param.instruction_n == 8 && | |||||
m_algo_param.instruction_k == 8); | |||||
available &= is_compute_capability_required(7, 5); | |||||
} | |||||
return available; | |||||
} | |||||
size_t MatrixMulForwardImpl::AlgoFloat16TensorOp::get_workspace_in_bytes( | |||||
const SizeArgs& args) const { | |||||
auto aligned = construct_aligned_layouts(args); | |||||
if (!aligned.first) | |||||
return 0_z; | |||||
const auto& layouts = aligned.second; | |||||
size_t ws_size = 0; | |||||
for (auto&& ly : layouts) { | |||||
ws_size += ly.span().dist_byte(); | |||||
} | |||||
return ws_size; | |||||
} | |||||
void MatrixMulForwardImpl::AlgoFloat16TensorOp::do_exec( | |||||
const ExecArgs& args) const { | |||||
int64_t lda = args.tensor_a.layout.stride[0], | |||||
ldb = args.tensor_b.layout.stride[0], | |||||
ldc = args.tensor_c.layout.stride[0]; | |||||
int alignment = max_alignment(args); | |||||
int min_alignment = min_alignment_requirement(); | |||||
auto&& param = args.opr->param(); | |||||
int m = args.tensor_c.layout.shape[0], n = args.tensor_c.layout.shape[1], | |||||
k = args.tensor_a.layout.shape[param.transposeA ? 0 : 1]; | |||||
megdnn_assert(lda % alignment == 0 && ldb % alignment == 0 && | |||||
ldc % alignment == 0 && m % alignment == 0 && | |||||
n % alignment == 0 && k % alignment == 0 && | |||||
alignment >= min_alignment); | |||||
cutlass::gemm::GemmCoord problem_size{m, n, k}; | |||||
auto&& stream = cuda_stream(args.opr->handle()); | |||||
int* workspace = reinterpret_cast<int*>(args.workspace.raw_ptr); | |||||
// \note these epilogue constants (i.e. one and zero) are passed to the | |||||
// cutlass epilogue by pointer and interpreted as ElementCompute*, which is | |||||
// then used to initialize the kernel parameters. The host-side arguments | |||||
// must therefore have the same type as the kernel instance's | |||||
// ElementCompute; otherwise the pointers are misinterpreted and the kernel | |||||
// behavior is undefined. | |||||
float one = 1.f, zero = 0.f; | |||||
dt_float16 one_f16 = static_cast<dt_float16>(one), | |||||
zero_f16 = static_cast<dt_float16>(zero); | |||||
using namespace cutlass::library; | |||||
auto layoutA = param.transposeA ? LayoutTypeID::kColumnMajor | |||||
: LayoutTypeID::kRowMajor; | |||||
auto layoutB = param.transposeB ? LayoutTypeID::kColumnMajor | |||||
: LayoutTypeID::kRowMajor; | |||||
void *host_one, *host_zero; | |||||
NumericTypeID element_accumulator; | |||||
if (param.compute_mode == param::MatrixMul::ComputeMode::DEFAULT) { | |||||
element_accumulator = NumericTypeID::kF16; | |||||
host_one = &one_f16; | |||||
host_zero = &zero_f16; | |||||
} else { | |||||
megdnn_assert(param.compute_mode == | |||||
param::MatrixMul::ComputeMode::FLOAT32); | |||||
element_accumulator = NumericTypeID::kF32; | |||||
host_one = &one; | |||||
host_zero = &zero; | |||||
} | |||||
GemmKey key{NumericTypeID::kF16, | |||||
layoutA, | |||||
NumericTypeID::kF16, | |||||
layoutB, | |||||
NumericTypeID::kF16, | |||||
LayoutTypeID::kRowMajor, | |||||
element_accumulator, | |||||
m_algo_param.threadblock_m, | |||||
m_algo_param.threadblock_n, | |||||
m_algo_param.threadblock_k, | |||||
m_algo_param.warp_m, | |||||
m_algo_param.warp_n, | |||||
m_algo_param.warp_k, | |||||
m_algo_param.instruction_m, | |||||
m_algo_param.instruction_n, | |||||
m_algo_param.instruction_k, | |||||
2, | |||||
alignment, | |||||
alignment, | |||||
SplitKMode::kNone}; | |||||
const auto& table = Singleton::get().operation_table; | |||||
megdnn_assert(table.gemm_operations.count(key) > 0, | |||||
"key not found in cutlass operation table"); | |||||
const auto& ops = table.gemm_operations.at(key); | |||||
megdnn_assert(ops.size() == 1, "exactly one kernel expected, got %zu", | |||||
ops.size()); | |||||
GemmArguments gemm_args{problem_size, | |||||
args.tensor_a.raw_ptr, | |||||
args.tensor_b.raw_ptr, | |||||
args.tensor_c.raw_ptr, | |||||
args.tensor_c.raw_ptr, | |||||
lda, | |||||
ldb, | |||||
ldc, | |||||
ldc, | |||||
1, | |||||
host_one, | |||||
host_zero}; | |||||
cutlass_check(ops[0]->run(&gemm_args, workspace, stream)); | |||||
} | |||||
#endif | |||||
// vim: syntax=cpp.doxygen |
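The \note about host-side alpha/beta typing is the subtle point of this file. A minimal sketch of the failure mode it guards against, using a hypothetical parameter struct (the real one lives inside the CUTLASS kernel instance):

#include <cstring>
#include <cuda_fp16.h>

// ElementCompute is __half when compute_mode is DEFAULT.
struct EpilogueParamsF16 {
    __half alpha, beta;
};

inline EpilogueParamsF16 load_params(const void* host_alpha,
                                     const void* host_beta) {
    EpilogueParamsF16 p;
    // The library only sees void* and copies sizeof(ElementCompute) bytes.
    // If the caller had passed a float* here, the low two bytes of the
    // float's bit pattern would be read as a half: no error, wrong numerics.
    std::memcpy(&p.alpha, host_alpha, sizeof(__half));
    std::memcpy(&p.beta, host_beta, sizeof(__half));
    return p;
}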
@@ -0,0 +1,165 @@ | |||||
/** | |||||
* \file dnn/src/cuda/matrix_mul/cutlass_float16_tensorop_split_k.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#include "src/cuda/cutlass/singleton.h" | |||||
#include "src/cuda/handle.h" | |||||
#include "src/cuda/matrix_mul/algos.h" | |||||
#include "src/cuda/utils.h" | |||||
#if CUDA_VERSION >= 9020 | |||||
using namespace megdnn; | |||||
using namespace cuda; | |||||
bool MatrixMulForwardImpl::AlgoFloat16TensorOpSplitK::is_available( | |||||
const SizeArgs& args) const { | |||||
auto&& param = args.opr->param(); | |||||
int n = args.layout_c.shape[1], | |||||
k = args.layout_a.shape[param.transposeA ? 0 : 1]; | |||||
bool available = | |||||
args.opr->param().format == param::MatrixMul::Format::DEFAULT && | |||||
args.layout_a.dtype == dtype::Float16() && | |||||
args.layout_b.dtype == dtype::Float16() && | |||||
args.layout_c.dtype == dtype::Float16() && k > n; | |||||
auto&& device_prop = cuda::current_device_prop(); | |||||
int y_grid_limit = device_prop.maxGridSize[1]; | |||||
// limit y grid | |||||
available &= ((n + m_algo_param.threadblock_n - 1) / | |||||
m_algo_param.threadblock_n <= | |||||
y_grid_limit); | |||||
if (m_algo_param.instruction_m == 8 && m_algo_param.instruction_n == 8 && | |||||
m_algo_param.instruction_k == 4) { | |||||
available &= is_compute_capability_required(7, 0); | |||||
} else { | |||||
megdnn_assert(m_algo_param.instruction_m == 16 && | |||||
m_algo_param.instruction_n == 8 && | |||||
m_algo_param.instruction_k == 8); | |||||
available &= is_compute_capability_required(7, 5); | |||||
} | |||||
return available; | |||||
} | |||||
size_t MatrixMulForwardImpl::AlgoFloat16TensorOpSplitK::get_workspace_in_bytes( | |||||
const SizeArgs& args) const { | |||||
auto aligned = construct_aligned_layouts(args); | |||||
auto&& param = args.opr->param(); | |||||
int m = args.layout_c.shape[0], n = args.layout_c.shape[1], | |||||
k = args.layout_a.shape[param.transposeA ? 0 : 1]; | |||||
int split_k_slices = std::max(1, k / n); | |||||
if (!aligned.first) | |||||
return args.layout_c.dtype.size(m * n * split_k_slices); | |||||
const auto& layouts = aligned.second; | |||||
int align_m = layouts[2].shape[0], align_n = layouts[2].shape[1], | |||||
align_k = layouts[0].shape[1]; | |||||
split_k_slices = std::max(1, align_k / align_n); | |||||
size_t ws_size = | |||||
args.layout_c.dtype.size(align_m * align_n * split_k_slices); | |||||
for (auto&& ly : layouts) | |||||
ws_size += ly.span().dist_byte(); | |||||
return ws_size; | |||||
} | |||||
void MatrixMulForwardImpl::AlgoFloat16TensorOpSplitK::do_exec( | |||||
const ExecArgs& args) const { | |||||
int64_t lda = args.tensor_a.layout.stride[0], | |||||
ldb = args.tensor_b.layout.stride[0], | |||||
ldc = args.tensor_c.layout.stride[0]; | |||||
int alignment = max_alignment(args); | |||||
int min_alignment = min_alignment_requirement(); | |||||
auto&& param = args.opr->param(); | |||||
int m = args.tensor_c.layout.shape[0], n = args.tensor_c.layout.shape[1], | |||||
k = args.tensor_a.layout.shape[param.transposeA ? 0 : 1]; | |||||
megdnn_assert(lda % alignment == 0 && ldb % alignment == 0 && | |||||
ldc % alignment == 0 && m % alignment == 0 && | |||||
n % alignment == 0 && k % alignment == 0 && | |||||
alignment >= min_alignment); | |||||
cutlass::gemm::GemmCoord problem_size{m, n, k}; | |||||
int split_k_slices = std::max(1, k / n); | |||||
auto&& stream = cuda_stream(args.opr->handle()); | |||||
int* workspace = reinterpret_cast<int*>(args.workspace.raw_ptr); | |||||
// \note these epilogue constants (i.e. one and zero) are passed to the | |||||
// cutlass epilogue by pointer and interpreted as ElementCompute*, which is | |||||
// then used to initialize the kernel parameters. The host-side arguments | |||||
// must therefore have the same type as the kernel instance's | |||||
// ElementCompute; otherwise the pointers are misinterpreted and the kernel | |||||
// behavior is undefined. | |||||
float one = 1.f, zero = 0.f; | |||||
dt_float16 one_f16 = static_cast<dt_float16>(one), | |||||
zero_f16 = static_cast<dt_float16>(zero); | |||||
using namespace cutlass::library; | |||||
auto layoutA = param.transposeA ? LayoutTypeID::kColumnMajor | |||||
: LayoutTypeID::kRowMajor; | |||||
auto layoutB = param.transposeB ? LayoutTypeID::kColumnMajor | |||||
: LayoutTypeID::kRowMajor; | |||||
void *host_one, *host_zero; | |||||
NumericTypeID element_accumulator; | |||||
if (param.compute_mode == param::MatrixMul::ComputeMode::DEFAULT) { | |||||
element_accumulator = NumericTypeID::kF16; | |||||
host_one = &one_f16; | |||||
host_zero = &zero_f16; | |||||
} else { | |||||
megdnn_assert(param.compute_mode == | |||||
param::MatrixMul::ComputeMode::FLOAT32); | |||||
element_accumulator = NumericTypeID::kF32; | |||||
host_one = &one; | |||||
host_zero = &zero; | |||||
} | |||||
GemmKey key{NumericTypeID::kF16, | |||||
layoutA, | |||||
NumericTypeID::kF16, | |||||
layoutB, | |||||
NumericTypeID::kF16, | |||||
LayoutTypeID::kRowMajor, | |||||
element_accumulator, | |||||
m_algo_param.threadblock_m, | |||||
m_algo_param.threadblock_n, | |||||
m_algo_param.threadblock_k, | |||||
m_algo_param.warp_m, | |||||
m_algo_param.warp_n, | |||||
m_algo_param.warp_k, | |||||
m_algo_param.instruction_m, | |||||
m_algo_param.instruction_n, | |||||
m_algo_param.instruction_k, | |||||
2, | |||||
alignment, | |||||
alignment, | |||||
SplitKMode::kParallel}; | |||||
const auto& table = Singleton::get().operation_table; | |||||
megdnn_assert(table.gemm_operations.count(key) > 0, | |||||
"key not found in cutlass operation table"); | |||||
const auto& ops = table.gemm_operations.at(key); | |||||
megdnn_assert(ops.size() == 1, "exactly one kernel expected, got %zu", | |||||
ops.size()); | |||||
GemmArguments gemm_args{problem_size, | |||||
args.tensor_a.raw_ptr, | |||||
args.tensor_b.raw_ptr, | |||||
args.tensor_c.raw_ptr, | |||||
args.tensor_c.raw_ptr, | |||||
lda, | |||||
ldb, | |||||
ldc, | |||||
ldc, | |||||
split_k_slices, | |||||
host_one, | |||||
host_zero}; | |||||
cutlass_check(ops[0]->run(&gemm_args, workspace, stream)); | |||||
} | |||||
#endif | |||||
// vim: syntax=cpp.doxygen |
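For reference, the workspace arithmetic of the split-k algorithm reduces to the following sketch of the already-aligned path (illustrative names; the padded path additionally adds the bytes of the three relayout buffers):

#include <algorithm>
#include <cstddef>

// Each of the split_k_slices partial products needs its own m x n buffer of
// C's dtype; a reduction kernel then folds the slices into C.
inline size_t split_k_workspace_bytes(int m, int n, int k,
                                      size_t c_dtype_size) {
    int split_k_slices = std::max(1, k / n);  // same heuristic as above
    return c_dtype_size * static_cast<size_t>(m) * n * split_k_slices;
}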
@@ -42,7 +42,8 @@ size_t MatrixMulForwardImpl::AlgoFloat32SIMT::get_workspace_in_bytes( | |||||
return 0_z; | return 0_z; | ||||
} | } | ||||
void MatrixMulForwardImpl::AlgoFloat32SIMT::exec(const ExecArgs& args) const { | |||||
void MatrixMulForwardImpl::AlgoFloat32SIMT::do_exec( | |||||
const ExecArgs& args) const { | |||||
int64_t lda = args.tensor_a.layout.stride[0], | int64_t lda = args.tensor_a.layout.stride[0], | ||||
ldb = args.tensor_b.layout.stride[0], | ldb = args.tensor_b.layout.stride[0], | ||||
ldc = args.tensor_c.layout.stride[0]; | ldc = args.tensor_c.layout.stride[0]; | ||||
@@ -65,12 +66,14 @@ void MatrixMulForwardImpl::AlgoFloat32SIMT::exec(const ExecArgs& args) const { | |||||
auto layoutB = param.transposeB ? LayoutTypeID::kColumnMajor | auto layoutB = param.transposeB ? LayoutTypeID::kColumnMajor | ||||
: LayoutTypeID::kRowMajor; | : LayoutTypeID::kRowMajor; | ||||
int alignment = min_alignment_requirement(); | |||||
GemmKey key{NumericTypeID::kF32, | GemmKey key{NumericTypeID::kF32, | ||||
layoutA, | layoutA, | ||||
NumericTypeID::kF32, | NumericTypeID::kF32, | ||||
layoutB, | layoutB, | ||||
NumericTypeID::kF32, | NumericTypeID::kF32, | ||||
LayoutTypeID::kRowMajor, | LayoutTypeID::kRowMajor, | ||||
NumericTypeID::kF32, | |||||
m_algo_param.threadblock_m, | m_algo_param.threadblock_m, | ||||
m_algo_param.threadblock_n, | m_algo_param.threadblock_n, | ||||
m_algo_param.threadblock_k, | m_algo_param.threadblock_k, | ||||
@@ -79,8 +82,10 @@ void MatrixMulForwardImpl::AlgoFloat32SIMT::exec(const ExecArgs& args) const { | |||||
m_algo_param.warp_k, | m_algo_param.warp_k, | ||||
1, | 1, | ||||
1, | 1, | ||||
1, | |||||
2, | |||||
1, | |||||
2, | |||||
alignment, | |||||
alignment, | |||||
SplitKMode::kNone}; | SplitKMode::kNone}; | ||||
const Operation* op = Singleton::get().operation_table.find_op(key); | const Operation* op = Singleton::get().operation_table.find_op(key); | ||||
@@ -22,7 +22,7 @@ using namespace cuda; | |||||
bool MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::is_available( | bool MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::is_available( | ||||
const SizeArgs& args) const { | const SizeArgs& args) const { | ||||
auto&& param = args.opr->param(); | auto&& param = args.opr->param(); | ||||
int m = args.layout_c.shape[0], n = args.layout_c.shape[1], | |||||
int n = args.layout_c.shape[1], | |||||
k = args.layout_a.shape[param.transposeA ? 0 : 1]; | k = args.layout_a.shape[param.transposeA ? 0 : 1]; | ||||
bool available = | bool available = | ||||
args.opr->param().format == param::MatrixMul::Format::DEFAULT && | args.opr->param().format == param::MatrixMul::Format::DEFAULT && | ||||
@@ -32,8 +32,8 @@ bool MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::is_available( | |||||
auto&& device_prop = cuda::current_device_prop(); | auto&& device_prop = cuda::current_device_prop(); | ||||
int y_grid_limit = device_prop.maxGridSize[1]; | int y_grid_limit = device_prop.maxGridSize[1]; | ||||
// limit y grid | // limit y grid | ||||
available &= ((m + m_algo_param.threadblock_m - 1) / | |||||
m_algo_param.threadblock_m <= | |||||
available &= ((n + m_algo_param.threadblock_n - 1) / | |||||
m_algo_param.threadblock_n <= | |||||
y_grid_limit); | y_grid_limit); | ||||
return available; | return available; | ||||
} | } | ||||
@@ -47,7 +47,7 @@ size_t MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::get_workspace_in_bytes( | |||||
return args.layout_c.dtype.size(m * n * split_k_slices); | return args.layout_c.dtype.size(m * n * split_k_slices); | ||||
} | } | ||||
void MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::exec( | |||||
void MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::do_exec( | |||||
const ExecArgs& args) const { | const ExecArgs& args) const { | ||||
int64_t lda = args.tensor_a.layout.stride[0], | int64_t lda = args.tensor_a.layout.stride[0], | ||||
ldb = args.tensor_b.layout.stride[0], | ldb = args.tensor_b.layout.stride[0], | ||||
@@ -72,12 +72,14 @@ void MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::exec( | |||||
auto layoutB = param.transposeB ? LayoutTypeID::kColumnMajor | auto layoutB = param.transposeB ? LayoutTypeID::kColumnMajor | ||||
: LayoutTypeID::kRowMajor; | : LayoutTypeID::kRowMajor; | ||||
int alignment = min_alignment_requirement(); | |||||
GemmKey key{NumericTypeID::kF32, | GemmKey key{NumericTypeID::kF32, | ||||
layoutA, | layoutA, | ||||
NumericTypeID::kF32, | NumericTypeID::kF32, | ||||
layoutB, | layoutB, | ||||
NumericTypeID::kF32, | NumericTypeID::kF32, | ||||
LayoutTypeID::kRowMajor, | LayoutTypeID::kRowMajor, | ||||
NumericTypeID::kF32, | |||||
m_algo_param.threadblock_m, | m_algo_param.threadblock_m, | ||||
m_algo_param.threadblock_n, | m_algo_param.threadblock_n, | ||||
m_algo_param.threadblock_k, | m_algo_param.threadblock_k, | ||||
@@ -87,7 +89,9 @@ void MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::exec( | |||||
1, | 1, | ||||
1, | 1, | ||||
1, | 1, | ||||
2, | |||||
2, | |||||
alignment, | |||||
alignment, | |||||
SplitKMode::kParallel}; | SplitKMode::kParallel}; | ||||
Operation const* op = Singleton::get().operation_table.find_op(key); | Operation const* op = Singleton::get().operation_table.find_op(key); | ||||
@@ -0,0 +1,136 @@ | |||||
/** | |||||
* \file dnn/src/cuda/matrix_mul/cutlass_matrix_mul_base.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#include "src/cuda/handle.h" | |||||
#include "src/cuda/matrix_mul/algos.h" | |||||
#include "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuh" | |||||
#include "src/cuda/utils.h" | |||||
#if CUDA_VERSION >= 9020 | |||||
using namespace megdnn; | |||||
using namespace cuda; | |||||
std::string | |||||
MatrixMulForwardImpl::AlgoCutlassMatrixMulBase::AlgoParam::to_string() const { | |||||
return ssprintf("%dX%dX%d_%dX%dX%d", threadblock_m, threadblock_n, | |||||
threadblock_k, warp_m, warp_n, warp_k); | |||||
} | |||||
std::pair<bool, TensorLayoutArray> | |||||
MatrixMulForwardImpl::AlgoCutlassMatrixMulBase::construct_aligned_layouts( | |||||
const SizeArgs& args) const { | |||||
int alignment = max_alignment(args); | |||||
int min_alignment = min_alignment_requirement(); | |||||
bool aligned = alignment >= min_alignment; | |||||
// the first element of the returned pair is true iff padded layouts were | |||||
// constructed, i.e. the caller must take the realignment path | |||||
if (aligned) | |||||
return std::make_pair(!aligned, TensorLayoutArray{{}}); | |||||
auto&& param = args.opr->param(); | |||||
int m = args.layout_c.shape[0], n = args.layout_c.shape[1], | |||||
k = args.layout_a.shape[param.transposeA ? 0 : 1]; | |||||
size_t align_m = get_aligned_power2(m, min_alignment); | |||||
size_t align_n = get_aligned_power2(n, min_alignment); | |||||
size_t align_k = get_aligned_power2(k, min_alignment); | |||||
TensorLayoutArray layouts; | |||||
layouts.emplace_back(TensorLayout{{align_m, align_k}, args.layout_a.dtype}); | |||||
layouts.emplace_back(TensorLayout{{align_k, align_n}, args.layout_b.dtype}); | |||||
layouts.emplace_back(TensorLayout{{align_m, align_n}, args.layout_c.dtype}); | |||||
return std::make_pair(!aligned, std::move(layouts)); | |||||
} | |||||
void MatrixMulForwardImpl::AlgoCutlassMatrixMulBase::exec( | |||||
const ExecArgs& args) const { | |||||
auto aligned = construct_aligned_layouts(args); | |||||
if (!aligned.first) | |||||
return do_exec(args); | |||||
const auto& layouts = aligned.second; | |||||
auto tensor_a = args.tensor_a; | |||||
auto tensor_b = args.tensor_b; | |||||
auto workspace = args.workspace; | |||||
size_t copy_size = 0; | |||||
for (const auto& ly : layouts) | |||||
copy_size += ly.span().dist_byte(); | |||||
auto&& param = args.opr->param(); | |||||
auto&& stream = cuda_stream(args.opr->handle()); | |||||
cuda_check(cudaMemsetAsync(workspace.raw_ptr, 0, copy_size, stream)); | |||||
auto&& relayout = args.opr->handle()->create_operator<RelayoutForward>(); | |||||
auto copy_stride = [](const TensorLayout& src, TensorLayout& dst, | |||||
bool trans) { | |||||
dst.stride[0] = src.stride[0], dst.stride[1] = src.stride[1]; | |||||
if (trans) | |||||
std::swap(dst.stride[0], dst.stride[1]); | |||||
}; | |||||
copy_stride(layouts[0], tensor_a.layout, param.transposeA); | |||||
tensor_a.raw_ptr = workspace.raw_ptr; | |||||
relayout->exec(args.tensor_a, tensor_a); | |||||
workspace.raw_ptr += layouts[0].span().dist_byte(); | |||||
workspace.size -= layouts[0].span().dist_byte(); | |||||
copy_stride(layouts[1], tensor_b.layout, param.transposeB); | |||||
tensor_b.raw_ptr = workspace.raw_ptr; | |||||
relayout->exec(args.tensor_b, tensor_b); | |||||
workspace.raw_ptr += layouts[1].span().dist_byte(); | |||||
workspace.size -= layouts[1].span().dist_byte(); | |||||
decltype(tensor_a) tensor_c{workspace.raw_ptr, layouts[2]}; | |||||
workspace.raw_ptr += layouts[2].span().dist_byte(); | |||||
workspace.size -= layouts[2].span().dist_byte(); | |||||
auto&& matmul = args.opr->handle()->create_operator<MatrixMulForward>(); | |||||
matmul->param().transposeA = false; | |||||
matmul->param().transposeB = false; | |||||
matmul->param().compute_mode = args.opr->param().compute_mode; | |||||
tensor_a.layout = layouts[0]; | |||||
tensor_b.layout = layouts[1]; | |||||
ExecArgs args_{static_cast<MatrixMulForwardImpl*>(matmul.get()), tensor_a, | |||||
tensor_b, tensor_c, workspace}; | |||||
do_exec(args_); | |||||
// keep the padded strides but shrink the shape back to C's logical shape, | |||||
// so the relayout below copies only the top-left m x n block | |||||
tensor_c.layout.TensorShape::operator=(args.layout_c); | |||||
relayout->exec(tensor_c, args.tensor_c); | |||||
} | |||||
int MatrixMulForwardImpl::AlgoCutlassMatrixMulBase::max_alignment( | |||||
const SizeArgs& args) const { | |||||
auto&& dtype_a = args.layout_a.dtype; | |||||
auto&& dtype_b = args.layout_b.dtype; | |||||
auto&& dtype_c = args.layout_c.dtype; | |||||
auto get_alignment = [](const DType& dt, int len) { | |||||
int size_bits = dt.size(1) * 8; | |||||
int align = 128; | |||||
while (align > 1) { | |||||
if ((len * size_bits) % align == 0) | |||||
break; | |||||
align = align / 2; | |||||
} | |||||
return align / size_bits; | |||||
}; | |||||
int lda = args.layout_a.stride[0], ldb = args.layout_b.stride[0], | |||||
ldc = args.layout_c.stride[0]; | |||||
auto&& param = args.opr->param(); | |||||
int m = args.layout_c.shape[0], n = args.layout_c.shape[1], | |||||
k = args.layout_a.shape[param.transposeA ? 0 : 1]; | |||||
int max_align = get_alignment(dtype_a, lda); | |||||
max_align = std::min(get_alignment(dtype_a, m), max_align); | |||||
max_align = std::min(get_alignment(dtype_a, n), max_align); | |||||
max_align = std::min(get_alignment(dtype_a, k), max_align); | |||||
max_align = std::min(get_alignment(dtype_b, ldb), max_align); | |||||
max_align = std::min(get_alignment(dtype_c, ldc), max_align); | |||||
return max_align; | |||||
} | |||||
#endif | |||||
// vim: syntax=cpp.doxygen |
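The get_alignment lambda in max_alignment() implements one simple rule: start from the widest 128-bit access, halve until the span divides evenly, and convert bits back to elements. A standalone restatement of that rule:

// Largest vectorized access (in elements) that evenly divides a span of
// `len` elements, capped at 128 bits -- mirrors the lambda above.
inline int alignment_in_elements(int element_bits, int len) {
    int align_bits = 128;
    while (align_bits > 1 && (len * element_bits) % align_bits != 0)
        align_bits /= 2;
    return align_bits / element_bits;
}
// e.g. f16 (16-bit) spans divisible by 8 allow 8-element (128-bit) accesses.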
@@ -42,9 +42,12 @@ public: | |||||
class AlgoBFloat16; | class AlgoBFloat16; | ||||
#endif | #endif | ||||
#if CUDA_VERSION >= 9020 | #if CUDA_VERSION >= 9020 | ||||
class AlgoCutlassMatrixMulBase; | |||||
class AlgoFloat32SIMT; | class AlgoFloat32SIMT; | ||||
class AlgoFloat32SIMTSplitK; | class AlgoFloat32SIMTSplitK; | ||||
class AlgoFloat32SIMTGemvBatchedStrided; | class AlgoFloat32SIMTGemvBatchedStrided; | ||||
class AlgoFloat16TensorOp; | |||||
class AlgoFloat16TensorOpSplitK; | |||||
#endif | #endif | ||||
class AlgoPack; | class AlgoPack; | ||||
@@ -184,7 +184,8 @@ void matrix_mul::check_matrix_mul(DType A_dtype, DType B_dtype, DType C_dtype, | |||||
const ExecutionPolicyAlgoName& algo, | const ExecutionPolicyAlgoName& algo, | ||||
param::MatrixMul::Format format, size_t nbase, | param::MatrixMul::Format format, size_t nbase, | ||||
float eps, std::vector<TestArg>&& user_args, | float eps, std::vector<TestArg>&& user_args, | ||||
bool force_deduce_dst) { | |||||
bool force_deduce_dst, | |||||
param::MatrixMul::ComputeMode compute_mode) { | |||||
megdnn_assert(A_dtype.enumv() == B_dtype.enumv()); | megdnn_assert(A_dtype.enumv() == B_dtype.enumv()); | ||||
Checker<Opr> checker(handle); | Checker<Opr> checker(handle); | ||||
checker.set_force_deduce_dst(force_deduce_dst); | checker.set_force_deduce_dst(force_deduce_dst); | ||||
@@ -261,6 +262,7 @@ void matrix_mul::check_matrix_mul(DType A_dtype, DType B_dtype, DType C_dtype, | |||||
Param param; | Param param; | ||||
param.transposeA = arg.mask & 0x1; | param.transposeA = arg.mask & 0x1; | ||||
param.transposeB = arg.mask & 0x2; | param.transposeB = arg.mask & 0x2; | ||||
param.compute_mode = compute_mode; | |||||
param.format = format; | param.format = format; | ||||
checker.set_dtype(0, A_dtype) | checker.set_dtype(0, A_dtype) | ||||
.set_dtype(1, B_dtype) | .set_dtype(1, B_dtype) | ||||
@@ -69,7 +69,9 @@ void check_matrix_mul( | |||||
const ExecutionPolicyAlgoName& algo = {"", {}}, | const ExecutionPolicyAlgoName& algo = {"", {}}, | ||||
param::MatrixMul::Format format = param::MatrixMul::Format::DEFAULT, | param::MatrixMul::Format format = param::MatrixMul::Format::DEFAULT, | ||||
size_t nbase = 8, float eps = 1e-3, std::vector<TestArg>&& args = {}, | size_t nbase = 8, float eps = 1e-3, std::vector<TestArg>&& args = {}, | ||||
bool force_deduce_dst = true); | |||||
bool force_deduce_dst = true, | |||||
param::MatrixMul::ComputeMode compute_mode = | |||||
param::MatrixMul::ComputeMode::DEFAULT); | |||||
void check_matrix_mul( | void check_matrix_mul( | ||||
DType A_dtype, DType B_dtype, DType C_dtype, Handle* handle, | DType A_dtype, DType B_dtype, DType C_dtype, Handle* handle, | ||||
@@ -21,6 +21,7 @@ | |||||
#include "test/cuda/fixture.h" | #include "test/cuda/fixture.h" | ||||
#include "test/cuda/utils.h" | #include "test/cuda/utils.h" | ||||
#if CUDA_VERSION >= 9020 | #if CUDA_VERSION >= 9020 | ||||
namespace megdnn { | namespace megdnn { | ||||
namespace test { | namespace test { | ||||
@@ -215,6 +216,14 @@ std::vector<BenchArgs> get_feat_model_args() { | |||||
return args; | return args; | ||||
} | } | ||||
std::vector<BenchArgs> get_f16_feat_model_args() { | |||||
std::vector<BenchArgs> args; | |||||
args.emplace_back(BenchArgs{128, 9216, 9216}); | |||||
args.emplace_back(BenchArgs{128, 6400, 6400}); | |||||
args.emplace_back(BenchArgs{128, 5184, 5184}); | |||||
return args; | |||||
} | |||||
void benchmark_matrix_mul( | void benchmark_matrix_mul( | ||||
Handle* handle, const std::vector<BenchArgs>& args, DType A_dtype, | Handle* handle, const std::vector<BenchArgs>& args, DType A_dtype, | ||||
DType B_dtype, DType C_dtype, const char* algo = nullptr, | DType B_dtype, DType C_dtype, const char* algo = nullptr, | ||||
@@ -364,6 +373,82 @@ MEGDNN_FOREACH_CUTLASS_KERNEL(cb) | |||||
#undef cb | #undef cb | ||||
#undef MEGDNN_FOREACH_CUTLASS_KERNEL | #undef MEGDNN_FOREACH_CUTLASS_KERNEL | ||||
#define MEGDNN_FOREACH_CUTLASS_KERNEL(cb) \ | |||||
cb(1, 256, 128, 32, 64, 64, 32, 8, 8, 4); \ | |||||
cb(2, 128, 256, 32, 64, 64, 32, 8, 8, 4); \ | |||||
cb(3, 128, 128, 32, 64, 64, 32, 8, 8, 4); | |||||
#define cb(name, tbm, tbn, tbk, wm, wn, wk, im, in, ik) \ | |||||
TEST_F(CUDA, CUTLASS_F16_884_GEMM_##name) { \ | |||||
require_compute_capability(7, 0); \ | |||||
matrix_mul::check_matrix_mul<MatrixMulForward>( \ | |||||
dtype::Float16(), dtype::Float16(), dtype::Float16(), \ | |||||
handle_cuda(), \ | |||||
"CUTLASS_FLOAT16_TENSOR_OP_h" #im #in #ik "_" #tbm "X" #tbn \ | |||||
"X" #tbk "_" #wm "X" #wn "X" #wk, \ | |||||
param::MatrixMul::Format::DEFAULT, 8, 1e-2, \ | |||||
matrix_mul::get_matmul_args()); \ | |||||
} | |||||
MEGDNN_FOREACH_CUTLASS_KERNEL(cb) | |||||
#undef cb | |||||
#define cb(name, tbm, tbn, tbk, wm, wn, wk, im, in, ik) \ | |||||
TEST_F(CUDA, CUTLASS_F16_884_GEMM_SPLIT_K_##name) { \ | |||||
require_compute_capability(7, 0); \ | |||||
matrix_mul::check_matrix_mul<MatrixMulForward>( \ | |||||
dtype::Float16(), dtype::Float16(), dtype::Float16(), \ | |||||
handle_cuda(), \ | |||||
"CUTLASS_FLOAT16_TENSOR_OP_SPLIT_K_h" #im #in #ik "_" #tbm \ | |||||
"X" #tbn "X" #tbk "_" #wm "X" #wn "X" #wk, \ | |||||
param::MatrixMul::Format::DEFAULT, 8, 1e-3, \ | |||||
matrix_mul::get_matmul_args_split_k(), true, \ | |||||
param::MatrixMul::ComputeMode::FLOAT32); \ | |||||
} | |||||
MEGDNN_FOREACH_CUTLASS_KERNEL(cb) | |||||
#undef cb | |||||
#undef MEGDNN_FOREACH_CUTLASS_KERNEL | |||||
#define MEGDNN_FOREACH_CUTLASS_KERNEL(cb) \ | |||||
cb(1, 256, 128, 32, 64, 64, 32, 16, 8, 8); \ | |||||
cb(2, 128, 256, 32, 64, 64, 32, 16, 8, 8); \ | |||||
cb(3, 128, 128, 32, 64, 64, 32, 16, 8, 8); | |||||
#define cb(name, tbm, tbn, tbk, wm, wn, wk, im, in, ik) \ | |||||
TEST_F(CUDA, CUTLASS_F16_1688_GEMM_##name) { \ | |||||
require_compute_capability(7, 5); \ | |||||
matrix_mul::check_matrix_mul<MatrixMulForward>( \ | |||||
dtype::Float16(), dtype::Float16(), dtype::Float16(), \ | |||||
handle_cuda(), \ | |||||
"CUTLASS_FLOAT16_TENSOR_OP_h" #im #in #ik "_" #tbm "X" #tbn \ | |||||
"X" #tbk "_" #wm "X" #wn "X" #wk, \ | |||||
param::MatrixMul::Format::DEFAULT, 8, 1e-2, \ | |||||
matrix_mul::get_matmul_args(), true, \ | |||||
param::MatrixMul::ComputeMode::FLOAT32); \ | |||||
} | |||||
MEGDNN_FOREACH_CUTLASS_KERNEL(cb) | |||||
#undef cb | |||||
#define cb(name, tbm, tbn, tbk, wm, wn, wk, im, in, ik) \ | |||||
TEST_F(CUDA, CUTLASS_F16_1688_GEMM_SPLIT_K_##name) { \ | |||||
require_compute_capability(7, 5); \ | |||||
matrix_mul::check_matrix_mul<MatrixMulForward>( \ | |||||
dtype::Float16(), dtype::Float16(), dtype::Float16(), \ | |||||
handle_cuda(), \ | |||||
"CUTLASS_FLOAT16_TENSOR_OP_SPLIT_K_h" #im #in #ik "_" #tbm \ | |||||
"X" #tbn "X" #tbk "_" #wm "X" #wn "X" #wk, \ | |||||
param::MatrixMul::Format::DEFAULT, 8, 1e-3, \ | |||||
matrix_mul::get_matmul_args_split_k()); \ | |||||
} | |||||
MEGDNN_FOREACH_CUTLASS_KERNEL(cb) | |||||
#undef cb | |||||
#undef MEGDNN_FOREACH_CUTLASS_KERNEL | |||||
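For readability, here is one instantiation of the 1688 macro above, cb(1, 256, 128, 32, 64, 64, 32, 16, 8, 8), expanded by hand; this is exactly what the preprocessor generates and is shown only for reference:

TEST_F(CUDA, CUTLASS_F16_1688_GEMM_1) {
    require_compute_capability(7, 5);
    matrix_mul::check_matrix_mul<MatrixMulForward>(
            dtype::Float16(), dtype::Float16(), dtype::Float16(),
            handle_cuda(),
            "CUTLASS_FLOAT16_TENSOR_OP_h1688_256X128X32_64X64X32",
            param::MatrixMul::Format::DEFAULT, 8, 1e-2,
            matrix_mul::get_matmul_args(), true,
            param::MatrixMul::ComputeMode::FLOAT32);
}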
#if MEGDNN_WITH_BENCHMARK | #if MEGDNN_WITH_BENCHMARK | ||||
TEST_F(CUDA, BENCHMARK_CUTLASS_MATMUL) { | TEST_F(CUDA, BENCHMARK_CUTLASS_MATMUL) { | ||||
benchmark_matrix_mul(handle_cuda(), get_square_matmul_args(), | benchmark_matrix_mul(handle_cuda(), get_square_matmul_args(), | ||||
@@ -376,6 +461,12 @@ TEST_F(CUDA, BENCHMARK_CUTLASS_MATMUL_FEAT) { | |||||
dtype::Float32(), dtype::Float32(), | dtype::Float32(), dtype::Float32(), | ||||
"CUTLASS_FLOAT32_SIMT"); | "CUTLASS_FLOAT32_SIMT"); | ||||
} | } | ||||
TEST_F(CUDA, BENCHMARK_CUTLASS_F16_MATMUL_FEAT) { | |||||
benchmark_matrix_mul(handle_cuda(), get_f16_feat_model_args(), | |||||
dtype::Float16(), dtype::Float16(), dtype::Float16(), | |||||
"CUTLASS_FLOAT16_TENSOR_OP"); | |||||
} | |||||
#endif | #endif | ||||
} // namespace test | } // namespace test | ||||
} // namespace megdnn | } // namespace megdnn | ||||