
feat(dnn/cuda): add tensorcore matmul for fp16 data type

GitOrigin-RevId: 025c591f75
tags/v1.6.0-rc1
Megvii Engine Team · 3 years ago
parent · commit 336761253d
20 changed files with 1168 additions and 49 deletions
  1. dnn/scripts/cutlass_generator/BUILD (+2, -0)
  2. dnn/scripts/cutlass_generator/gemm_operation.py (+22, -3)
  3. dnn/scripts/cutlass_generator/gen_list.py (+2, -0)
  4. dnn/scripts/cutlass_generator/generator.py (+134, -4)
  5. dnn/scripts/cutlass_generator/list.bzl (+291, -1)
  6. dnn/src/CMakeLists.txt (+2, -0)
  7. dnn/src/cuda/cutlass/initialize_all.cu (+4, -0)
  8. dnn/src/cuda/cutlass/operation_table.cpp (+4, -0)
  9. dnn/src/cuda/cutlass/operation_table.h (+14, -1)
  10. dnn/src/cuda/matrix_mul/algos.cpp (+20, -1)
  11. dnn/src/cuda/matrix_mul/algos.h (+101, -29)
  12. dnn/src/cuda/matrix_mul/cutlass_float16_tensorop.cpp (+154, -0)
  13. dnn/src/cuda/matrix_mul/cutlass_float16_tensorop_split_k.cpp (+165, -0)
  14. dnn/src/cuda/matrix_mul/cutlass_float32_simt.cpp (+8, -3)
  15. dnn/src/cuda/matrix_mul/cutlass_float32_simt_split_k.cpp (+9, -5)
  16. dnn/src/cuda/matrix_mul/cutlass_matrix_mul_base.cpp (+136, -0)
  17. dnn/src/cuda/matrix_mul/opr_impl.h (+3, -0)
  18. dnn/test/common/matrix_mul.cpp (+3, -1)
  19. dnn/test/common/matrix_mul.h (+3, -1)
  20. dnn/test/cuda/cutlass_matmul.cpp (+91, -0)

dnn/scripts/cutlass_generator/BUILD (+2, -0)

@@ -5,6 +5,8 @@ genrule(
    outs = cutlass_gen_list,
    cmd = """GEN=$(location //brain/megbrain/dnn/scripts/cutlass_generator:generator.py)
CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations gemm --type simt $(@D)
+CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations gemm --type tensorop884 $(@D)
+CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations gemm --type tensorop1688 $(@D)
CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations gemv --type simt $(@D)
CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations deconv --type simt $(@D)
CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations conv2d --type simt $(@D)


dnn/scripts/cutlass_generator/gemm_operation.py (+22, -3)

@@ -252,7 +252,8 @@ def GeneratesGemm(tile, data_type, layout_a, layout_b, layout_c, min_cc, align_a
    if tile.math_instruction.element_accumulator == DataType.s32:
        epilogues = [EpilogueFunctor.LinearCombinationClamp]
    else:
-       assert tile.math_instruction.element_accumulator == DataType.f32
+       assert tile.math_instruction.element_accumulator == DataType.f32 or \
+              tile.math_instruction.element_accumulator == DataType.f16
        epilogues = [EpilogueFunctor.LinearCombination]

    for epilogue in epilogues:
@@ -799,7 +800,22 @@ class EmitGemmSplitKParallelInstance:
    ${epilogue_vector_length},
    ${element_accumulator},
    ${element_epilogue}
-  >
+  >,
+  cutlass::epilogue::thread::Convert<
+    ${element_accumulator},
+    ${epilogue_vector_length},
+    ${element_accumulator}
+  >,
+  cutlass::reduction::thread::ReduceAdd<
+    ${element_accumulator},
+    ${element_accumulator},
+    ${epilogue_vector_length}
+  >,
+  cutlass::gemm::threadblock::GemmSplitKHorizontalThreadblockSwizzle,
+  ${stages},
+  ${align_a},
+  ${align_b},
+  ${math_operation}
>;
"""
    def emit(self, operation):
@@ -831,7 +847,10 @@ class EmitGemmSplitKParallelInstance:
            'epilogue_vector_length': str(epilogue_vector_length),
            'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
            'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
-           'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor],
+           'stages': str(operation.tile_description.stages),
+           'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation],
+           'align_a': str(operation.A.alignment),
+           'align_b': str(operation.B.alignment),
        }

        return SubstituteTemplate(self.template, values)


dnn/scripts/cutlass_generator/gen_list.py (+2, -0)

@@ -32,6 +32,8 @@ if __name__ == "__main__":
        f.write("# Generated by dnn/scripts/cutlass_generator/gen_list.py\n\n")
        f.write("cutlass_gen_list = [\n")
        write_op_list(f, "gemm", "simt")
+       write_op_list(f, "gemm", "tensorop1688")
+       write_op_list(f, "gemm", "tensorop884")
        write_op_list(f, "gemv", "simt")
        write_op_list(f, "deconv", "simt")
        write_op_list(f, "conv2d", "simt")


dnn/scripts/cutlass_generator/generator.py (+134, -4)

@@ -597,6 +597,131 @@ def GenerateGemv_Simt(args):
    return operations


#
def GeneratesGemm_TensorOp_1688(args):
layouts = [
(LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.RowMajor), # nn
(LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.RowMajor), # nt
(LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor), # tn
(LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.RowMajor), # tt
]
math_instructions = [
MathInstruction( \
[16, 8, 8], \
DataType.f16, DataType.f16, DataType.f32, \
OpcodeClass.TensorOp, \
MathOperation.multiply_add),
MathInstruction( \
[16, 8, 8], \
DataType.f16, DataType.f16, DataType.f16, \
OpcodeClass.TensorOp, \
MathOperation.multiply_add),
]

min_cc = 75
max_cc = 1024

alignment_constraints = [8, 4, 2,
#1
]

operations = []
for math_inst in math_instructions:
for layout in layouts:
for align in alignment_constraints:
tile_descriptions = [
TileDescription([256, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 256, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc),
TileDescription([128, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
## comment some configuration to reduce compilation time and binary size
# TileDescription([ 64, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
# TileDescription([128, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
# TileDescription([ 64, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
]
data_type = [
math_inst.element_a,
math_inst.element_b,
math_inst.element_a,
math_inst.element_accumulator,
]
for tile in tile_descriptions:
operations += GeneratesGemm(tile, \
data_type, \
layout[0], \
layout[1], \
layout[2], \
min_cc, \
align * 16, \
align * 16, \
align * 16)
return operations

#
def GeneratesGemm_TensorOp_884(args):
layouts = [
(LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.RowMajor), # nn
(LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.RowMajor), # nt
(LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor), # tn
(LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.RowMajor), # tt
]

math_instructions = [
MathInstruction( \
[8, 8, 4], \
DataType.f16, DataType.f16, DataType.f32, \
OpcodeClass.TensorOp, \
MathOperation.multiply_add),
MathInstruction( \
[8, 8, 4], \
DataType.f16, DataType.f16, DataType.f16, \
OpcodeClass.TensorOp, \
MathOperation.multiply_add),
]

min_cc = 70
max_cc = 75

alignment_constraints = [8, 4, 2,
# 1
]

operations = []
for math_inst in math_instructions:
for layout in layouts:
for align in alignment_constraints:
tile_descriptions = [
TileDescription([256, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 256, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc),
TileDescription([128, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
## comment some configuration to reduce compilation time and binary size
# TileDescription([ 64, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
# TileDescription([128, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
# TileDescription([ 64, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
]
data_type = [
math_inst.element_a,
math_inst.element_b,
math_inst.element_a,
math_inst.element_accumulator,
]
for tile in tile_descriptions:
operations += GeneratesGemm(tile, \
data_type, \
layout[0], \
layout[1], \
layout[2], \
min_cc, \
align * 16, \
align * 16, \
align * 16)
return operations
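Both GEMM generators above hand GeneratesGemm an alignment of align * 16. The following standalone C++ sketch shows one plausible reading of that factor (an assumption on my part, not something the patch states): the alignment constraints are counted in fp16 elements, and multiplying by 16 converts them to bits, since a half element is 16 bits wide.

    #include <cstdio>

    int main() {
        const int fp16_bits = 16;                      // bit width of one half element
        const int alignment_constraints[] = {8, 4, 2}; // element counts used above
        for (int align_elems : alignment_constraints) {
            int align_bits = align_elems * fp16_bits;  // value passed as align * 16
            std::printf("align%d -> %d bits (%d-byte vector access)\n",
                        align_elems, align_bits, align_bits / 8);
        }
        return 0;
    }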

#
def GenerateConv2dOperations(args):
    if args.type == "simt":
        return GenerateConv2d_Simt(args)
@@ -613,9 +738,14 @@ def GenerateDeconvOperations(args):
    return GenerateDeconv_Simt(args)


def GenerateGemmOperations(args):
-   assert args.type == "simt", "operation gemm only support" \
-           "simt. (got:{})".format(args.type)
-   return GenerateGemm_Simt(args)
+   if args.type == "tensorop884":
+       return GeneratesGemm_TensorOp_884(args)
+   elif args.type == "tensorop1688":
+       return GeneratesGemm_TensorOp_1688(args)
+   else:
+       assert args.type == "simt", "operation gemm only support" \
+               "simt. (got:{})".format(args.type)
+       return GenerateGemm_Simt(args)


def GenerateGemvOperations(args):
    assert args.type == "simt", "operation gemv only support" \
@@ -631,7 +761,7 @@ if __name__ == "__main__":
parser.add_argument("--operations", type=str, choices=['gemm', 'gemv', 'conv2d', 'deconv'], parser.add_argument("--operations", type=str, choices=['gemm', 'gemv', 'conv2d', 'deconv'],
required=True, help="Specifies the operation to generate (gemm, gemv, conv2d, deconv)") required=True, help="Specifies the operation to generate (gemm, gemv, conv2d, deconv)")
parser.add_argument("output", type=str, help="output directory for CUTLASS kernel files") parser.add_argument("output", type=str, help="output directory for CUTLASS kernel files")
parser.add_argument("--type", type=str, choices=['simt', 'tensorop8816', 'tensorop8832'],
parser.add_argument("--type", type=str, choices=['simt', 'tensorop8816', 'tensorop8832', 'tensorop884', 'tensorop1688'],
default='simt', help="kernel type of CUTLASS kernel generator") default='simt', help="kernel type of CUTLASS kernel generator")


gemv_wrapper_path = "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper_batched_gemv_strided.cuinl" gemv_wrapper_path = "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper_batched_gemv_strided.cuinl"


dnn/scripts/cutlass_generator/list.bzl (+291, -1)

@@ -138,6 +138,296 @@ cutlass_gen_list = [
"cutlass_simt_sgemm_256x64_8x2_tt_align1.cu", "cutlass_simt_sgemm_256x64_8x2_tt_align1.cu",
"cutlass_simt_sgemm_split_k_parallel_256x64_8x2_tt_align1.cu", "cutlass_simt_sgemm_split_k_parallel_256x64_8x2_tt_align1.cu",
"all_gemm_simt_operations.cu", "all_gemm_simt_operations.cu",
"cutlass_tensorop_f16_s1688gemm_f16_256x128_32x2_nn_align8.cu",
"cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_256x128_32x2_nn_align8.cu",
"cutlass_tensorop_f16_s1688gemm_f16_128x256_32x2_nn_align8.cu",
"cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x256_32x2_nn_align8.cu",
"cutlass_tensorop_f16_s1688gemm_f16_128x128_32x2_nn_align8.cu",
"cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x128_32x2_nn_align8.cu",
"cutlass_tensorop_f16_s1688gemm_f16_256x128_32x2_nn_align4.cu",
"cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_256x128_32x2_nn_align4.cu",
"cutlass_tensorop_f16_s1688gemm_f16_128x256_32x2_nn_align4.cu",
"cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x256_32x2_nn_align4.cu",
"cutlass_tensorop_f16_s1688gemm_f16_128x128_32x2_nn_align4.cu",
"cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x128_32x2_nn_align4.cu",
"cutlass_tensorop_f16_s1688gemm_f16_256x128_32x2_nn_align2.cu",
"cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_256x128_32x2_nn_align2.cu",
"cutlass_tensorop_f16_s1688gemm_f16_128x256_32x2_nn_align2.cu",
"cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x256_32x2_nn_align2.cu",
"cutlass_tensorop_f16_s1688gemm_f16_128x128_32x2_nn_align2.cu",
"cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x128_32x2_nn_align2.cu",
"cutlass_tensorop_f16_s1688gemm_f16_256x128_32x2_nt_align8.cu",
"cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_256x128_32x2_nt_align8.cu",
"cutlass_tensorop_f16_s1688gemm_f16_128x256_32x2_nt_align8.cu",
"cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x256_32x2_nt_align8.cu",
"cutlass_tensorop_f16_s1688gemm_f16_128x128_32x2_nt_align8.cu",
"cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x128_32x2_nt_align8.cu",
"cutlass_tensorop_f16_s1688gemm_f16_256x128_32x2_nt_align4.cu",
"cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_256x128_32x2_nt_align4.cu",
"cutlass_tensorop_f16_s1688gemm_f16_128x256_32x2_nt_align4.cu",
"cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x256_32x2_nt_align4.cu",
"cutlass_tensorop_f16_s1688gemm_f16_128x128_32x2_nt_align4.cu",
"cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x128_32x2_nt_align4.cu",
"cutlass_tensorop_f16_s1688gemm_f16_256x128_32x2_nt_align2.cu",
"cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_256x128_32x2_nt_align2.cu",
"cutlass_tensorop_f16_s1688gemm_f16_128x256_32x2_nt_align2.cu",
"cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x256_32x2_nt_align2.cu",
"cutlass_tensorop_f16_s1688gemm_f16_128x128_32x2_nt_align2.cu",
"cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x128_32x2_nt_align2.cu",
"cutlass_tensorop_f16_s1688gemm_f16_256x128_32x2_tn_align8.cu",
"cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_256x128_32x2_tn_align8.cu",
"cutlass_tensorop_f16_s1688gemm_f16_128x256_32x2_tn_align8.cu",
"cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x256_32x2_tn_align8.cu",
"cutlass_tensorop_f16_s1688gemm_f16_128x128_32x2_tn_align8.cu",
"cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x128_32x2_tn_align8.cu",
"cutlass_tensorop_f16_s1688gemm_f16_256x128_32x2_tn_align4.cu",
"cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_256x128_32x2_tn_align4.cu",
"cutlass_tensorop_f16_s1688gemm_f16_128x256_32x2_tn_align4.cu",
"cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x256_32x2_tn_align4.cu",
"cutlass_tensorop_f16_s1688gemm_f16_128x128_32x2_tn_align4.cu",
"cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x128_32x2_tn_align4.cu",
"cutlass_tensorop_f16_s1688gemm_f16_256x128_32x2_tn_align2.cu",
"cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_256x128_32x2_tn_align2.cu",
"cutlass_tensorop_f16_s1688gemm_f16_128x256_32x2_tn_align2.cu",
"cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x256_32x2_tn_align2.cu",
"cutlass_tensorop_f16_s1688gemm_f16_128x128_32x2_tn_align2.cu",
"cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x128_32x2_tn_align2.cu",
"cutlass_tensorop_f16_s1688gemm_f16_256x128_32x2_tt_align8.cu",
"cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_256x128_32x2_tt_align8.cu",
"cutlass_tensorop_f16_s1688gemm_f16_128x256_32x2_tt_align8.cu",
"cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x256_32x2_tt_align8.cu",
"cutlass_tensorop_f16_s1688gemm_f16_128x128_32x2_tt_align8.cu",
"cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x128_32x2_tt_align8.cu",
"cutlass_tensorop_f16_s1688gemm_f16_256x128_32x2_tt_align4.cu",
"cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_256x128_32x2_tt_align4.cu",
"cutlass_tensorop_f16_s1688gemm_f16_128x256_32x2_tt_align4.cu",
"cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x256_32x2_tt_align4.cu",
"cutlass_tensorop_f16_s1688gemm_f16_128x128_32x2_tt_align4.cu",
"cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x128_32x2_tt_align4.cu",
"cutlass_tensorop_f16_s1688gemm_f16_256x128_32x2_tt_align2.cu",
"cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_256x128_32x2_tt_align2.cu",
"cutlass_tensorop_f16_s1688gemm_f16_128x256_32x2_tt_align2.cu",
"cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x256_32x2_tt_align2.cu",
"cutlass_tensorop_f16_s1688gemm_f16_128x128_32x2_tt_align2.cu",
"cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x128_32x2_tt_align2.cu",
"cutlass_tensorop_h1688gemm_256x128_32x2_nn_align8.cu",
"cutlass_tensorop_h1688gemm_split_k_parallel_256x128_32x2_nn_align8.cu",
"cutlass_tensorop_h1688gemm_128x256_32x2_nn_align8.cu",
"cutlass_tensorop_h1688gemm_split_k_parallel_128x256_32x2_nn_align8.cu",
"cutlass_tensorop_h1688gemm_128x128_32x2_nn_align8.cu",
"cutlass_tensorop_h1688gemm_split_k_parallel_128x128_32x2_nn_align8.cu",
"cutlass_tensorop_h1688gemm_256x128_32x2_nn_align4.cu",
"cutlass_tensorop_h1688gemm_split_k_parallel_256x128_32x2_nn_align4.cu",
"cutlass_tensorop_h1688gemm_128x256_32x2_nn_align4.cu",
"cutlass_tensorop_h1688gemm_split_k_parallel_128x256_32x2_nn_align4.cu",
"cutlass_tensorop_h1688gemm_128x128_32x2_nn_align4.cu",
"cutlass_tensorop_h1688gemm_split_k_parallel_128x128_32x2_nn_align4.cu",
"cutlass_tensorop_h1688gemm_256x128_32x2_nn_align2.cu",
"cutlass_tensorop_h1688gemm_split_k_parallel_256x128_32x2_nn_align2.cu",
"cutlass_tensorop_h1688gemm_128x256_32x2_nn_align2.cu",
"cutlass_tensorop_h1688gemm_split_k_parallel_128x256_32x2_nn_align2.cu",
"cutlass_tensorop_h1688gemm_128x128_32x2_nn_align2.cu",
"cutlass_tensorop_h1688gemm_split_k_parallel_128x128_32x2_nn_align2.cu",
"cutlass_tensorop_h1688gemm_256x128_32x2_nt_align8.cu",
"cutlass_tensorop_h1688gemm_split_k_parallel_256x128_32x2_nt_align8.cu",
"cutlass_tensorop_h1688gemm_128x256_32x2_nt_align8.cu",
"cutlass_tensorop_h1688gemm_split_k_parallel_128x256_32x2_nt_align8.cu",
"cutlass_tensorop_h1688gemm_128x128_32x2_nt_align8.cu",
"cutlass_tensorop_h1688gemm_split_k_parallel_128x128_32x2_nt_align8.cu",
"cutlass_tensorop_h1688gemm_256x128_32x2_nt_align4.cu",
"cutlass_tensorop_h1688gemm_split_k_parallel_256x128_32x2_nt_align4.cu",
"cutlass_tensorop_h1688gemm_128x256_32x2_nt_align4.cu",
"cutlass_tensorop_h1688gemm_split_k_parallel_128x256_32x2_nt_align4.cu",
"cutlass_tensorop_h1688gemm_128x128_32x2_nt_align4.cu",
"cutlass_tensorop_h1688gemm_split_k_parallel_128x128_32x2_nt_align4.cu",
"cutlass_tensorop_h1688gemm_256x128_32x2_nt_align2.cu",
"cutlass_tensorop_h1688gemm_split_k_parallel_256x128_32x2_nt_align2.cu",
"cutlass_tensorop_h1688gemm_128x256_32x2_nt_align2.cu",
"cutlass_tensorop_h1688gemm_split_k_parallel_128x256_32x2_nt_align2.cu",
"cutlass_tensorop_h1688gemm_128x128_32x2_nt_align2.cu",
"cutlass_tensorop_h1688gemm_split_k_parallel_128x128_32x2_nt_align2.cu",
"cutlass_tensorop_h1688gemm_256x128_32x2_tn_align8.cu",
"cutlass_tensorop_h1688gemm_split_k_parallel_256x128_32x2_tn_align8.cu",
"cutlass_tensorop_h1688gemm_128x256_32x2_tn_align8.cu",
"cutlass_tensorop_h1688gemm_split_k_parallel_128x256_32x2_tn_align8.cu",
"cutlass_tensorop_h1688gemm_128x128_32x2_tn_align8.cu",
"cutlass_tensorop_h1688gemm_split_k_parallel_128x128_32x2_tn_align8.cu",
"cutlass_tensorop_h1688gemm_256x128_32x2_tn_align4.cu",
"cutlass_tensorop_h1688gemm_split_k_parallel_256x128_32x2_tn_align4.cu",
"cutlass_tensorop_h1688gemm_128x256_32x2_tn_align4.cu",
"cutlass_tensorop_h1688gemm_split_k_parallel_128x256_32x2_tn_align4.cu",
"cutlass_tensorop_h1688gemm_128x128_32x2_tn_align4.cu",
"cutlass_tensorop_h1688gemm_split_k_parallel_128x128_32x2_tn_align4.cu",
"cutlass_tensorop_h1688gemm_256x128_32x2_tn_align2.cu",
"cutlass_tensorop_h1688gemm_split_k_parallel_256x128_32x2_tn_align2.cu",
"cutlass_tensorop_h1688gemm_128x256_32x2_tn_align2.cu",
"cutlass_tensorop_h1688gemm_split_k_parallel_128x256_32x2_tn_align2.cu",
"cutlass_tensorop_h1688gemm_128x128_32x2_tn_align2.cu",
"cutlass_tensorop_h1688gemm_split_k_parallel_128x128_32x2_tn_align2.cu",
"cutlass_tensorop_h1688gemm_256x128_32x2_tt_align8.cu",
"cutlass_tensorop_h1688gemm_split_k_parallel_256x128_32x2_tt_align8.cu",
"cutlass_tensorop_h1688gemm_128x256_32x2_tt_align8.cu",
"cutlass_tensorop_h1688gemm_split_k_parallel_128x256_32x2_tt_align8.cu",
"cutlass_tensorop_h1688gemm_128x128_32x2_tt_align8.cu",
"cutlass_tensorop_h1688gemm_split_k_parallel_128x128_32x2_tt_align8.cu",
"cutlass_tensorop_h1688gemm_256x128_32x2_tt_align4.cu",
"cutlass_tensorop_h1688gemm_split_k_parallel_256x128_32x2_tt_align4.cu",
"cutlass_tensorop_h1688gemm_128x256_32x2_tt_align4.cu",
"cutlass_tensorop_h1688gemm_split_k_parallel_128x256_32x2_tt_align4.cu",
"cutlass_tensorop_h1688gemm_128x128_32x2_tt_align4.cu",
"cutlass_tensorop_h1688gemm_split_k_parallel_128x128_32x2_tt_align4.cu",
"cutlass_tensorop_h1688gemm_256x128_32x2_tt_align2.cu",
"cutlass_tensorop_h1688gemm_split_k_parallel_256x128_32x2_tt_align2.cu",
"cutlass_tensorop_h1688gemm_128x256_32x2_tt_align2.cu",
"cutlass_tensorop_h1688gemm_split_k_parallel_128x256_32x2_tt_align2.cu",
"cutlass_tensorop_h1688gemm_128x128_32x2_tt_align2.cu",
"cutlass_tensorop_h1688gemm_split_k_parallel_128x128_32x2_tt_align2.cu",
"all_gemm_tensorop1688_operations.cu",
"cutlass_tensorop_f16_s884gemm_f16_256x128_32x2_nn_align8.cu",
"cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_256x128_32x2_nn_align8.cu",
"cutlass_tensorop_f16_s884gemm_f16_128x256_32x2_nn_align8.cu",
"cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x256_32x2_nn_align8.cu",
"cutlass_tensorop_f16_s884gemm_f16_128x128_32x2_nn_align8.cu",
"cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x128_32x2_nn_align8.cu",
"cutlass_tensorop_f16_s884gemm_f16_256x128_32x2_nn_align4.cu",
"cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_256x128_32x2_nn_align4.cu",
"cutlass_tensorop_f16_s884gemm_f16_128x256_32x2_nn_align4.cu",
"cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x256_32x2_nn_align4.cu",
"cutlass_tensorop_f16_s884gemm_f16_128x128_32x2_nn_align4.cu",
"cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x128_32x2_nn_align4.cu",
"cutlass_tensorop_f16_s884gemm_f16_256x128_32x2_nn_align2.cu",
"cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_256x128_32x2_nn_align2.cu",
"cutlass_tensorop_f16_s884gemm_f16_128x256_32x2_nn_align2.cu",
"cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x256_32x2_nn_align2.cu",
"cutlass_tensorop_f16_s884gemm_f16_128x128_32x2_nn_align2.cu",
"cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x128_32x2_nn_align2.cu",
"cutlass_tensorop_f16_s884gemm_f16_256x128_32x2_nt_align8.cu",
"cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_256x128_32x2_nt_align8.cu",
"cutlass_tensorop_f16_s884gemm_f16_128x256_32x2_nt_align8.cu",
"cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x256_32x2_nt_align8.cu",
"cutlass_tensorop_f16_s884gemm_f16_128x128_32x2_nt_align8.cu",
"cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x128_32x2_nt_align8.cu",
"cutlass_tensorop_f16_s884gemm_f16_256x128_32x2_nt_align4.cu",
"cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_256x128_32x2_nt_align4.cu",
"cutlass_tensorop_f16_s884gemm_f16_128x256_32x2_nt_align4.cu",
"cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x256_32x2_nt_align4.cu",
"cutlass_tensorop_f16_s884gemm_f16_128x128_32x2_nt_align4.cu",
"cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x128_32x2_nt_align4.cu",
"cutlass_tensorop_f16_s884gemm_f16_256x128_32x2_nt_align2.cu",
"cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_256x128_32x2_nt_align2.cu",
"cutlass_tensorop_f16_s884gemm_f16_128x256_32x2_nt_align2.cu",
"cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x256_32x2_nt_align2.cu",
"cutlass_tensorop_f16_s884gemm_f16_128x128_32x2_nt_align2.cu",
"cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x128_32x2_nt_align2.cu",
"cutlass_tensorop_f16_s884gemm_f16_256x128_32x2_tn_align8.cu",
"cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_256x128_32x2_tn_align8.cu",
"cutlass_tensorop_f16_s884gemm_f16_128x256_32x2_tn_align8.cu",
"cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x256_32x2_tn_align8.cu",
"cutlass_tensorop_f16_s884gemm_f16_128x128_32x2_tn_align8.cu",
"cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x128_32x2_tn_align8.cu",
"cutlass_tensorop_f16_s884gemm_f16_256x128_32x2_tn_align4.cu",
"cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_256x128_32x2_tn_align4.cu",
"cutlass_tensorop_f16_s884gemm_f16_128x256_32x2_tn_align4.cu",
"cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x256_32x2_tn_align4.cu",
"cutlass_tensorop_f16_s884gemm_f16_128x128_32x2_tn_align4.cu",
"cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x128_32x2_tn_align4.cu",
"cutlass_tensorop_f16_s884gemm_f16_256x128_32x2_tn_align2.cu",
"cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_256x128_32x2_tn_align2.cu",
"cutlass_tensorop_f16_s884gemm_f16_128x256_32x2_tn_align2.cu",
"cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x256_32x2_tn_align2.cu",
"cutlass_tensorop_f16_s884gemm_f16_128x128_32x2_tn_align2.cu",
"cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x128_32x2_tn_align2.cu",
"cutlass_tensorop_f16_s884gemm_f16_256x128_32x2_tt_align8.cu",
"cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_256x128_32x2_tt_align8.cu",
"cutlass_tensorop_f16_s884gemm_f16_128x256_32x2_tt_align8.cu",
"cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x256_32x2_tt_align8.cu",
"cutlass_tensorop_f16_s884gemm_f16_128x128_32x2_tt_align8.cu",
"cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x128_32x2_tt_align8.cu",
"cutlass_tensorop_f16_s884gemm_f16_256x128_32x2_tt_align4.cu",
"cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_256x128_32x2_tt_align4.cu",
"cutlass_tensorop_f16_s884gemm_f16_128x256_32x2_tt_align4.cu",
"cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x256_32x2_tt_align4.cu",
"cutlass_tensorop_f16_s884gemm_f16_128x128_32x2_tt_align4.cu",
"cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x128_32x2_tt_align4.cu",
"cutlass_tensorop_f16_s884gemm_f16_256x128_32x2_tt_align2.cu",
"cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_256x128_32x2_tt_align2.cu",
"cutlass_tensorop_f16_s884gemm_f16_128x256_32x2_tt_align2.cu",
"cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x256_32x2_tt_align2.cu",
"cutlass_tensorop_f16_s884gemm_f16_128x128_32x2_tt_align2.cu",
"cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x128_32x2_tt_align2.cu",
"cutlass_tensorop_h884gemm_256x128_32x2_nn_align8.cu",
"cutlass_tensorop_h884gemm_split_k_parallel_256x128_32x2_nn_align8.cu",
"cutlass_tensorop_h884gemm_128x256_32x2_nn_align8.cu",
"cutlass_tensorop_h884gemm_split_k_parallel_128x256_32x2_nn_align8.cu",
"cutlass_tensorop_h884gemm_128x128_32x2_nn_align8.cu",
"cutlass_tensorop_h884gemm_split_k_parallel_128x128_32x2_nn_align8.cu",
"cutlass_tensorop_h884gemm_256x128_32x2_nn_align4.cu",
"cutlass_tensorop_h884gemm_split_k_parallel_256x128_32x2_nn_align4.cu",
"cutlass_tensorop_h884gemm_128x256_32x2_nn_align4.cu",
"cutlass_tensorop_h884gemm_split_k_parallel_128x256_32x2_nn_align4.cu",
"cutlass_tensorop_h884gemm_128x128_32x2_nn_align4.cu",
"cutlass_tensorop_h884gemm_split_k_parallel_128x128_32x2_nn_align4.cu",
"cutlass_tensorop_h884gemm_256x128_32x2_nn_align2.cu",
"cutlass_tensorop_h884gemm_split_k_parallel_256x128_32x2_nn_align2.cu",
"cutlass_tensorop_h884gemm_128x256_32x2_nn_align2.cu",
"cutlass_tensorop_h884gemm_split_k_parallel_128x256_32x2_nn_align2.cu",
"cutlass_tensorop_h884gemm_128x128_32x2_nn_align2.cu",
"cutlass_tensorop_h884gemm_split_k_parallel_128x128_32x2_nn_align2.cu",
"cutlass_tensorop_h884gemm_256x128_32x2_nt_align8.cu",
"cutlass_tensorop_h884gemm_split_k_parallel_256x128_32x2_nt_align8.cu",
"cutlass_tensorop_h884gemm_128x256_32x2_nt_align8.cu",
"cutlass_tensorop_h884gemm_split_k_parallel_128x256_32x2_nt_align8.cu",
"cutlass_tensorop_h884gemm_128x128_32x2_nt_align8.cu",
"cutlass_tensorop_h884gemm_split_k_parallel_128x128_32x2_nt_align8.cu",
"cutlass_tensorop_h884gemm_256x128_32x2_nt_align4.cu",
"cutlass_tensorop_h884gemm_split_k_parallel_256x128_32x2_nt_align4.cu",
"cutlass_tensorop_h884gemm_128x256_32x2_nt_align4.cu",
"cutlass_tensorop_h884gemm_split_k_parallel_128x256_32x2_nt_align4.cu",
"cutlass_tensorop_h884gemm_128x128_32x2_nt_align4.cu",
"cutlass_tensorop_h884gemm_split_k_parallel_128x128_32x2_nt_align4.cu",
"cutlass_tensorop_h884gemm_256x128_32x2_nt_align2.cu",
"cutlass_tensorop_h884gemm_split_k_parallel_256x128_32x2_nt_align2.cu",
"cutlass_tensorop_h884gemm_128x256_32x2_nt_align2.cu",
"cutlass_tensorop_h884gemm_split_k_parallel_128x256_32x2_nt_align2.cu",
"cutlass_tensorop_h884gemm_128x128_32x2_nt_align2.cu",
"cutlass_tensorop_h884gemm_split_k_parallel_128x128_32x2_nt_align2.cu",
"cutlass_tensorop_h884gemm_256x128_32x2_tn_align8.cu",
"cutlass_tensorop_h884gemm_split_k_parallel_256x128_32x2_tn_align8.cu",
"cutlass_tensorop_h884gemm_128x256_32x2_tn_align8.cu",
"cutlass_tensorop_h884gemm_split_k_parallel_128x256_32x2_tn_align8.cu",
"cutlass_tensorop_h884gemm_128x128_32x2_tn_align8.cu",
"cutlass_tensorop_h884gemm_split_k_parallel_128x128_32x2_tn_align8.cu",
"cutlass_tensorop_h884gemm_256x128_32x2_tn_align4.cu",
"cutlass_tensorop_h884gemm_split_k_parallel_256x128_32x2_tn_align4.cu",
"cutlass_tensorop_h884gemm_128x256_32x2_tn_align4.cu",
"cutlass_tensorop_h884gemm_split_k_parallel_128x256_32x2_tn_align4.cu",
"cutlass_tensorop_h884gemm_128x128_32x2_tn_align4.cu",
"cutlass_tensorop_h884gemm_split_k_parallel_128x128_32x2_tn_align4.cu",
"cutlass_tensorop_h884gemm_256x128_32x2_tn_align2.cu",
"cutlass_tensorop_h884gemm_split_k_parallel_256x128_32x2_tn_align2.cu",
"cutlass_tensorop_h884gemm_128x256_32x2_tn_align2.cu",
"cutlass_tensorop_h884gemm_split_k_parallel_128x256_32x2_tn_align2.cu",
"cutlass_tensorop_h884gemm_128x128_32x2_tn_align2.cu",
"cutlass_tensorop_h884gemm_split_k_parallel_128x128_32x2_tn_align2.cu",
"cutlass_tensorop_h884gemm_256x128_32x2_tt_align8.cu",
"cutlass_tensorop_h884gemm_split_k_parallel_256x128_32x2_tt_align8.cu",
"cutlass_tensorop_h884gemm_128x256_32x2_tt_align8.cu",
"cutlass_tensorop_h884gemm_split_k_parallel_128x256_32x2_tt_align8.cu",
"cutlass_tensorop_h884gemm_128x128_32x2_tt_align8.cu",
"cutlass_tensorop_h884gemm_split_k_parallel_128x128_32x2_tt_align8.cu",
"cutlass_tensorop_h884gemm_256x128_32x2_tt_align4.cu",
"cutlass_tensorop_h884gemm_split_k_parallel_256x128_32x2_tt_align4.cu",
"cutlass_tensorop_h884gemm_128x256_32x2_tt_align4.cu",
"cutlass_tensorop_h884gemm_split_k_parallel_128x256_32x2_tt_align4.cu",
"cutlass_tensorop_h884gemm_128x128_32x2_tt_align4.cu",
"cutlass_tensorop_h884gemm_split_k_parallel_128x128_32x2_tt_align4.cu",
"cutlass_tensorop_h884gemm_256x128_32x2_tt_align2.cu",
"cutlass_tensorop_h884gemm_split_k_parallel_256x128_32x2_tt_align2.cu",
"cutlass_tensorop_h884gemm_128x256_32x2_tt_align2.cu",
"cutlass_tensorop_h884gemm_split_k_parallel_128x256_32x2_tt_align2.cu",
"cutlass_tensorop_h884gemm_128x128_32x2_tt_align2.cu",
"cutlass_tensorop_h884gemm_split_k_parallel_128x128_32x2_tt_align2.cu",
"all_gemm_tensorop884_operations.cu",
"cutlass_simt_sgemv_batched_strided_1x128_32_tt_align4x4.cu", "cutlass_simt_sgemv_batched_strided_1x128_32_tt_align4x4.cu",
"cutlass_simt_sgemv_batched_strided_1x128_16_tt_align4x2.cu", "cutlass_simt_sgemv_batched_strided_1x128_16_tt_align4x2.cu",
"cutlass_simt_sgemv_batched_strided_1x128_8_tt_align4x1.cu", "cutlass_simt_sgemv_batched_strided_1x128_8_tt_align4x1.cu",
@@ -646,4 +936,4 @@ cutlass_gen_list = [
"cutlass_tensorop_u4_i8832fprop_roc_relu_u4_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", "cutlass_tensorop_u4_i8832fprop_roc_relu_u4_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu",
"cutlass_tensorop_u4_i8832fprop_1x1_roc_relu_u4_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", "cutlass_tensorop_u4_i8832fprop_1x1_roc_relu_u4_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu",
"all_conv2d_tensorop8832_operations.cu", "all_conv2d_tensorop8832_operations.cu",
]
]

dnn/src/CMakeLists.txt (+2, -0)

@@ -151,6 +151,8 @@ if(MGE_WITH_CUDA)
    set(${gen_files} "${${gen_files}}" PARENT_SCOPE)
  endfunction()
  gen_cutlass_kimpl(gemm simt CUTLASS_SOURCES)
+ gen_cutlass_kimpl(gemm tensorop884 CUTLASS_SOURCES)
+ gen_cutlass_kimpl(gemm tensorop1688 CUTLASS_SOURCES)
  gen_cutlass_kimpl(gemv simt CUTLASS_SOURCES)
  gen_cutlass_kimpl(deconv simt CUTLASS_SOURCES)
  gen_cutlass_kimpl(conv2d simt CUTLASS_SOURCES)


dnn/src/cuda/cutlass/initialize_all.cu (+4, -0)

@@ -49,6 +49,8 @@ namespace library {
        (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)

void initialize_all_gemm_simt_operations(Manifest& manifest);
+void initialize_all_gemm_tensorop884_operations(Manifest& manifest);
+void initialize_all_gemm_tensorop1688_operations(Manifest& manifest);
void initialize_all_conv2d_simt_operations(Manifest& manifest);
void initialize_all_conv2d_tensorop8816_operations(Manifest& manifest);
void initialize_all_conv2d_tensorop8832_operations(Manifest& manifest);
@@ -56,6 +58,8 @@ void initialize_all_deconv_simt_operations(Manifest& manifest);


void initialize_all(Manifest& manifest) {
    initialize_all_gemm_simt_operations(manifest);
+   initialize_all_gemm_tensorop884_operations(manifest);
+   initialize_all_gemm_tensorop1688_operations(manifest);
    initialize_all_conv2d_simt_operations(manifest);
    initialize_all_conv2d_tensorop8816_operations(manifest);
    initialize_all_conv2d_tensorop8832_operations(manifest);


dnn/src/cuda/cutlass/operation_table.cpp (+4, -0)

@@ -55,6 +55,8 @@ GemmKey get_gemm_key_from_desc(const GemmDescription& desc) {
    key.layout_B = desc.B.layout;
    key.element_C = desc.C.element;
    key.layout_C = desc.C.layout;
+   key.element_accumulator =
+           desc.tile_description.math_instruction.element_accumulator;

    key.threadblock_shape_m = desc.tile_description.threadblock_shape.m();
    key.threadblock_shape_n = desc.tile_description.threadblock_shape.n();
@@ -75,6 +77,8 @@ GemmKey get_gemm_key_from_desc(const GemmDescription& desc) {
            desc.tile_description.math_instruction.instruction_shape.k();

    key.stages = desc.stages;
+   key.alignment_A = desc.A.alignment;
+   key.alignment_B = desc.B.alignment;
    key.split_k_mode = desc.split_k_mode;

    return key;


dnn/src/cuda/cutlass/operation_table.h (+14, -1)

@@ -77,6 +77,7 @@ struct GemmKey {
    LayoutTypeID layout_B;
    NumericTypeID element_C;
    LayoutTypeID layout_C;
+   NumericTypeID element_accumulator;

    int threadblock_shape_m;
    int threadblock_shape_n;
@@ -91,12 +92,15 @@ struct GemmKey {
    int instruction_shape_k;

    int stages;
+   int alignment_A;
+   int alignment_B;
    SplitKMode split_k_mode;

    inline bool operator==(GemmKey const& rhs) const {
        return (element_A == rhs.element_A) && (layout_A == rhs.layout_A) &&
               (element_B == rhs.element_B) && (layout_B == rhs.layout_B) &&
               (element_C == rhs.element_C) && (layout_C == rhs.layout_C) &&
+              (element_accumulator == rhs.element_accumulator) &&
               (threadblock_shape_m == rhs.threadblock_shape_m) &&
               (threadblock_shape_n == rhs.threadblock_shape_n) &&
               (threadblock_shape_k == rhs.threadblock_shape_k) &&
@@ -106,7 +110,9 @@ struct GemmKey {
               (instruction_shape_m == rhs.instruction_shape_m) &&
               (instruction_shape_n == rhs.instruction_shape_n) &&
               (instruction_shape_k == rhs.instruction_shape_k) &&
-              (stages == rhs.stages) && (split_k_mode == rhs.split_k_mode);
+              (stages == rhs.stages) && (alignment_A == rhs.alignment_A) &&
+              (alignment_B == rhs.alignment_B) &&
+              (split_k_mode == rhs.split_k_mode);
    }

    inline bool operator!=(GemmKey const& rhs) const { return !(*this == rhs); }
@@ -130,10 +136,13 @@ struct GemmKey {
"\n layout_B: " + to_string(layout_B) + "\n layout_B: " + to_string(layout_B) +
"\n element_C: " + to_string(element_C) + "\n element_C: " + to_string(element_C) +
"\n layout_C: " + to_string(layout_C) + "\n layout_C: " + to_string(layout_C) +
"\n element_accumulator: " + to_string(element_accumulator) +
"\n threadblock_shape: " + threadblock_shape_str + "\n threadblock_shape: " + threadblock_shape_str +
"\n warp_shape: " + warp_shape_str + "\n warp_shape: " + warp_shape_str +
"\n instruction_shape: " + instruction_shape_str + "\n instruction_shape: " + instruction_shape_str +
"\n stages: " + std::to_string(stages) + "\n stages: " + std::to_string(stages) +
"\n alignment_A: " + std::to_string(alignment_A) +
"\n alignment_B: " + std::to_string(alignment_B) +
"\n split_k_mode: " + to_string(split_k_mode) + "\n}"; "\n split_k_mode: " + to_string(split_k_mode) + "\n}";
} }
}; };
@@ -147,6 +156,8 @@ struct GemmKeyHasher {
                .update(&key.layout_B, sizeof(key.layout_B))
                .update(&key.element_C, sizeof(key.element_C))
                .update(&key.layout_C, sizeof(key.layout_C))
+               .update(&key.element_accumulator,
+                       sizeof(key.element_accumulator))
                .update(&key.threadblock_shape_m,
                        sizeof(key.threadblock_shape_m))
                .update(&key.threadblock_shape_n,
@@ -157,6 +168,8 @@ struct GemmKeyHasher {
                .update(&key.warp_shape_n, sizeof(key.warp_shape_n))
                .update(&key.warp_shape_k, sizeof(key.warp_shape_k))
                .update(&key.stages, sizeof(key.stages))
+               .update(&key.alignment_A, sizeof(key.alignment_A))
+               .update(&key.alignment_B, sizeof(key.alignment_B))
                .update(&key.split_k_mode, sizeof(key.split_k_mode))
                .digest();
    }


dnn/src/cuda/matrix_mul/algos.cpp (+20, -1)

@@ -43,6 +43,12 @@ MatrixMulForwardImpl::AlgoPack::AlgoPack() {
    for (auto&& algo : simt_float32_gemv_batched_strided) {
        all_algos.push_back(&algo);
    }
+   for (auto&& algo : tensorop_float16) {
+       all_algos.push_back(&algo);
+   }
+   for (auto&& algo : tensorop_float16_split_k) {
+       all_algos.push_back(&algo);
+   }
#endif
    all_algos.push_back(&naive);


@@ -53,7 +59,7 @@ MatrixMulForwardImpl::AlgoPack::AlgoPack() {


#if CUDA_VERSION >= 9020
void MatrixMulForwardImpl::AlgoPack::fill_cutlass_algos() {
-   using AlgoParam = AlgoFloat32SIMT::AlgoParam;
+   using AlgoParam = AlgoCutlassMatrixMulBase::AlgoParam;
    simt_float32.emplace_back(AlgoParam{64, 256, 8, 32, 64, 8});
    simt_float32.emplace_back(AlgoParam{256, 64, 8, 64, 32, 8});
    simt_float32.emplace_back(AlgoParam{32, 256, 8, 16, 64, 8});
@@ -91,6 +97,19 @@ void MatrixMulForwardImpl::AlgoPack::fill_cutlass_algos() {
    simt_float32_gemv_batched_strided.emplace_back(128);
    simt_float32_gemv_batched_strided.emplace_back(64);
    simt_float32_gemv_batched_strided.emplace_back(32);
+#define FOREACH_CUTLASS_MATMUL_F16_SHAPES(cb) \
+    cb(256, 128, 32, 64, 64, 32, 8, 8, 4);    \
+    cb(128, 256, 32, 64, 64, 32, 8, 8, 4);    \
+    cb(128, 128, 32, 64, 64, 32, 8, 8, 4);    \
+    cb(256, 128, 32, 64, 64, 32, 16, 8, 8);   \
+    cb(128, 256, 32, 64, 64, 32, 16, 8, 8);   \
+    cb(128, 128, 32, 64, 64, 32, 16, 8, 8);
+#define cb(...)                                                \
+    tensorop_float16.emplace_back(AlgoParam{__VA_ARGS__});     \
+    tensorop_float16_split_k.emplace_back(AlgoParam{__VA_ARGS__});
+    FOREACH_CUTLASS_MATMUL_F16_SHAPES(cb)
+#undef cb
+#undef FOREACH_CUTLASS_MATMUL_F16_SHAPES
}
#endif
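For readability, here is a minimal standalone C++ sketch (not part of the patch) of what a single cb(...) line of FOREACH_CUTLASS_MATMUL_F16_SHAPES expands to: each nine-integer tuple registers one plain fp16 tensor op algorithm and one split-k variant with identical threadblock/warp/instruction shapes. AlgoParam is simplified to an aggregate here; the real type is defined in algos.h.

    #include <cstdio>
    #include <vector>

    // Simplified stand-in for AlgoCutlassMatrixMulBase::AlgoParam.
    struct AlgoParam {
        int threadblock_m, threadblock_n, threadblock_k;
        int warp_m, warp_n, warp_k;
        int instruction_m, instruction_n, instruction_k;
    };

    int main() {
        std::vector<AlgoParam> tensorop_float16, tensorop_float16_split_k;
        // cb(256, 128, 32, 64, 64, 32, 8, 8, 4) expands to these two statements:
        tensorop_float16.push_back({256, 128, 32, 64, 64, 32, 8, 8, 4});
        tensorop_float16_split_k.push_back({256, 128, 32, 64, 64, 32, 8, 8, 4});
        std::printf("registered %zu + %zu algorithms for this tile shape\n",
                    tensorop_float16.size(), tensorop_float16_split_k.size());
        return 0;
    }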




dnn/src/cuda/matrix_mul/algos.h (+101, -29)

@@ -41,11 +41,13 @@ public:
        CUDA_WMMA_UINT4X4X32,
        CUDA_CUBLASLT,
        CUDA_NAIVE,
        CUDA_BFLOAT16,
#if CUDA_VERSION >= 9020
        CUDA_FLOAT32_SIMT,
        CUDA_FLOAT32_SIMT_SPLIT_K,
        CUDA_FLOAT32_SIMT_GEMV_BATCHED_STRIDED,
+       CUDA_FLOAT16_TENSOR_OP,
+       CUDA_FLOAT16_TENSOR_OP_SPLIT_K,
#endif
    };
    using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>;
@@ -188,65 +190,83 @@ private:
#endif

#if CUDA_VERSION >= 9020
-class MatrixMulForwardImpl::AlgoFloat32SIMT final : public AlgoBase {
+class MatrixMulForwardImpl::AlgoCutlassMatrixMulBase : public AlgoBase {
public:
    struct AlgoParam {
        int threadblock_m, threadblock_n, threadblock_k;
        int warp_m, warp_n, warp_k;
-       std::string to_string() {
-           return ssprintf("%dX%dX%d_%dX%dX%d", threadblock_m, threadblock_n,
-                           threadblock_k, warp_m, warp_n, warp_k);
-       }
+       int instruction_m, instruction_n, instruction_k;
+       AlgoParam(int threadblock_m_, int threadblock_n_, int threadblock_k_,
+                 int warp_m_, int warp_n_, int warp_k_, int instruction_m_ = 1,
+                 int instruction_n_ = 1, int instruction_k_ = 1)
+               : threadblock_m{threadblock_m_},
+                 threadblock_n{threadblock_n_},
+                 threadblock_k{threadblock_k_},
+                 warp_m{warp_m_},
+                 warp_n{warp_n_},
+                 warp_k{warp_k_},
+                 instruction_m{instruction_m_},
+                 instruction_n{instruction_n_},
+                 instruction_k{instruction_k_} {}
+       std::string to_string() const;
    };
+   AlgoCutlassMatrixMulBase(AlgoParam algo_param) : m_algo_param{algo_param} {}
+   void exec(const ExecArgs& args) const override;
+   std::string param() const override {
+       std::string ret;
+       serialize_write_pod(m_algo_param, ret);
+       return ret;
+   }
+
+protected:
+   virtual int min_alignment_requirement() const = 0;
+   virtual void do_exec(const ExecArgs& args) const = 0;
+   std::pair<bool, TensorLayoutArray> construct_aligned_layouts(
+           const SizeArgs& args) const;
+   int max_alignment(const SizeArgs& args) const;
+   AlgoParam m_algo_param;
+};
+
+class MatrixMulForwardImpl::AlgoFloat32SIMT final
+       : public AlgoCutlassMatrixMulBase {
public:
    AlgoFloat32SIMT(AlgoParam algo_param)
-           : m_algo_param{algo_param},
+           : AlgoCutlassMatrixMulBase{algo_param},
              m_name{ssprintf("CUTLASS_FLOAT32_SIMT_%s",
                              m_algo_param.to_string().c_str())} {}
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    const char* name() const override { return m_name.c_str(); }
-   void exec(const ExecArgs& args) const override;
    AlgoAttribute attribute() const override {
        return AlgoAttribute::REPRODUCIBLE;
    }
    MEGDNN_DECL_ALGO_TYPE(CUDA_FLOAT32_SIMT)

-   std::string param() const override {
-       std::string ret;
-       serialize_write_pod(m_algo_param, ret);
-       return ret;
-   }
-
private:
-   AlgoParam m_algo_param;
+   void do_exec(const ExecArgs& args) const override;
+   int min_alignment_requirement() const override { return 1; }
    std::string m_name;
};

-class MatrixMulForwardImpl::AlgoFloat32SIMTSplitK final : public AlgoBase {
+class MatrixMulForwardImpl::AlgoFloat32SIMTSplitK final
+       : public AlgoCutlassMatrixMulBase {
public:
-   using AlgoParam = MatrixMulForwardImpl::AlgoFloat32SIMT::AlgoParam;
    AlgoFloat32SIMTSplitK(AlgoParam algo_param)
-           : m_algo_param{algo_param},
+           : AlgoCutlassMatrixMulBase{algo_param},
              m_name{ssprintf("CUTLASS_FLOAT32_SIMT_SPLIT_K_%s",
                              m_algo_param.to_string().c_str())} {}
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    const char* name() const override { return m_name.c_str(); }
-   void exec(const ExecArgs& args) const override;
    AlgoAttribute attribute() const override {
        return AlgoAttribute::REPRODUCIBLE |
               AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
    }
    MEGDNN_DECL_ALGO_TYPE(CUDA_FLOAT32_SIMT_SPLIT_K)

-   std::string param() const override {
-       std::string ret;
-       serialize_write_pod(m_algo_param, ret);
-       return ret;
-   }
-
private:
-   AlgoParam m_algo_param;
+   void do_exec(const ExecArgs& args) const override;
+   int min_alignment_requirement() const override { return 1; }
    std::string m_name;
};


@@ -276,6 +296,56 @@ private:
    int m_threadblock_n;
    std::string m_name;
};

class MatrixMulForwardImpl::AlgoFloat16TensorOp final
: public AlgoCutlassMatrixMulBase {
public:
AlgoFloat16TensorOp(AlgoParam algo_param)
: AlgoCutlassMatrixMulBase{algo_param},
m_name{ssprintf("CUTLASS_FLOAT16_TENSOR_OP_h%d%d%d_%s",
m_algo_param.instruction_m,
m_algo_param.instruction_n,
m_algo_param.instruction_k,
m_algo_param.to_string().c_str())} {}
bool is_available(const SizeArgs& args) const override;
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
const char* name() const override { return m_name.c_str(); }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
MEGDNN_DECL_ALGO_TYPE(CUDA_FLOAT16_TENSOR_OP)

private:
void do_exec(const ExecArgs& args) const override;
int min_alignment_requirement() const override { return 2; }
std::string m_name;
};

class MatrixMulForwardImpl::AlgoFloat16TensorOpSplitK final
: public AlgoCutlassMatrixMulBase {
public:
AlgoFloat16TensorOpSplitK(AlgoParam algo_param)
: AlgoCutlassMatrixMulBase{algo_param},
m_name{ssprintf("CUTLASS_FLOAT16_TENSOR_OP_SPLIT_K_h%d%d%d_%s",
m_algo_param.instruction_m,
m_algo_param.instruction_n,
m_algo_param.instruction_k,
m_algo_param.to_string().c_str())} {}
bool is_available(const SizeArgs& args) const override;
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
const char* name() const override { return m_name.c_str(); }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE |
AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
}
MEGDNN_DECL_ALGO_TYPE(CUDA_FLOAT16_TENSOR_OP_SPLIT_K)

private:
void do_exec(const ExecArgs& args) const override;
int min_alignment_requirement() const override { return 2; }
std::string m_name;
};

#endif

class MatrixMulForwardImpl::AlgoPack : NonCopyableObj {
@@ -300,6 +370,8 @@ public:
    std::vector<AlgoFloat32SIMTSplitK> simt_float32_split_k;
    std::vector<AlgoFloat32SIMTGemvBatchedStrided>
            simt_float32_gemv_batched_strided;
+   std::vector<AlgoFloat16TensorOp> tensorop_float16;
+   std::vector<AlgoFloat16TensorOpSplitK> tensorop_float16_split_k;
#endif
    std::vector<AlgoBase*> all_algos;
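A small standalone sketch of the user-visible names the two new fp16 algorithm classes above expose, assuming AlgoParam::to_string() keeps the "%dX%dX%d_%dX%dX%d" threadblock/warp layout of the previous inline definition (its new out-of-line implementation lives in cutlass_matrix_mul_base.cpp, which is not shown here):

    #include <cstdio>

    int main() {
        // Tile shape {256,128,32, 64,64,32, 8,8,4} registered in fill_cutlass_algos().
        int tb[3] = {256, 128, 32}, warp[3] = {64, 64, 32}, inst[3] = {8, 8, 4};
        std::printf("CUTLASS_FLOAT16_TENSOR_OP_h%d%d%d_%dX%dX%d_%dX%dX%d\n",
                    inst[0], inst[1], inst[2], tb[0], tb[1], tb[2],
                    warp[0], warp[1], warp[2]);
        std::printf("CUTLASS_FLOAT16_TENSOR_OP_SPLIT_K_h%d%d%d_%dX%dX%d_%dX%dX%d\n",
                    inst[0], inst[1], inst[2], tb[0], tb[1], tb[2],
                    warp[0], warp[1], warp[2]);
        // Expected output:
        //   CUTLASS_FLOAT16_TENSOR_OP_h884_256X128X32_64X64X32
        //   CUTLASS_FLOAT16_TENSOR_OP_SPLIT_K_h884_256X128X32_64X64X32
        return 0;
    }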




dnn/src/cuda/matrix_mul/cutlass_float16_tensorop.cpp (+154, -0)

@@ -0,0 +1,154 @@
/**
* \file dnn/src/cuda/matrix_mul/cutlass_float16_tensorop.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#include "src/cuda/cutlass/singleton.h"
#include "src/cuda/handle.h"
#include "src/cuda/matrix_mul/algos.h"
#include "src/cuda/utils.h"

#if CUDA_VERSION >= 9020
using namespace megdnn;
using namespace cuda;

bool MatrixMulForwardImpl::AlgoFloat16TensorOp::is_available(
const SizeArgs& args) const {
bool available =
args.opr->param().format == param::MatrixMul::Format::DEFAULT &&
args.layout_b.dtype == dtype::Float16() &&
args.layout_c.dtype == dtype::Float16();
int n = args.layout_c.shape[1];
auto&& device_prop = cuda::current_device_prop();
int y_grid_limit = device_prop.maxGridSize[1];
// limit y grid
available &= ((n + m_algo_param.threadblock_n - 1) /
m_algo_param.threadblock_n <=
y_grid_limit);
if (m_algo_param.instruction_m == 8 && m_algo_param.instruction_n == 8 &&
m_algo_param.instruction_k == 4) {
available &= is_compute_capability_required(7, 0);
} else {
megdnn_assert(m_algo_param.instruction_m == 16 &&
m_algo_param.instruction_n == 8 &&
m_algo_param.instruction_k == 8);
available &= is_compute_capability_required(7, 5);
}

return available;
}

size_t MatrixMulForwardImpl::AlgoFloat16TensorOp::get_workspace_in_bytes(
const SizeArgs& args) const {
auto aligned = construct_aligned_layouts(args);
if (!aligned.first)
return 0_z;
const auto& layouts = aligned.second;
size_t ws_size = 0;
for (auto&& ly : layouts) {
ws_size += ly.span().dist_byte();
}
return ws_size;
}

void MatrixMulForwardImpl::AlgoFloat16TensorOp::do_exec(
const ExecArgs& args) const {
int64_t lda = args.tensor_a.layout.stride[0],
ldb = args.tensor_b.layout.stride[0],
ldc = args.tensor_c.layout.stride[0];
int alignment = max_alignment(args);
int min_alignment = min_alignment_requirement();
auto&& param = args.opr->param();
int m = args.tensor_c.layout.shape[0], n = args.tensor_c.layout.shape[1],
k = args.tensor_a.layout.shape[param.transposeA ? 0 : 1];
megdnn_assert(lda % alignment == 0 && ldb % alignment == 0 &&
ldc % alignment == 0 && m % alignment == 0 &&
n % alignment == 0 && k % alignment == 0 &&
alignment >= min_alignment);
cutlass::gemm::GemmCoord problem_size{m, n, k};
auto&& stream = cuda_stream(args.opr->handle());
int* workspace = reinterpret_cast<int*>(args.workspace.raw_ptr);
// \note these constants (i.e. one and zero) of cutlass epilogue will be
// passed by pointers and interpreted as ElementCompute*, which will be used
// to initialize kernel parameters. So the arguments' type on the host side
// should be the same as the ElementCompute of kernel instance, otherwise
// undefined kernel behaviors will occur, caused by incorrect interpretation
// of these pointers.
float one = 1.f, zero = 0.f;
dt_float16 one_f16 = static_cast<dt_float16>(one),
zero_f16 = static_cast<dt_float16>(zero);

using namespace cutlass::library;

auto layoutA = param.transposeA ? LayoutTypeID::kColumnMajor
: LayoutTypeID::kRowMajor;
auto layoutB = param.transposeB ? LayoutTypeID::kColumnMajor
: LayoutTypeID::kRowMajor;

void *host_one, *host_zero;
NumericTypeID element_accumulator;
if (param.compute_mode == param::MatrixMul::ComputeMode::DEFAULT) {
element_accumulator = NumericTypeID::kF16;
host_one = &one_f16;
host_zero = &zero_f16;
} else {
megdnn_assert(param.compute_mode ==
param::MatrixMul::ComputeMode::FLOAT32);
element_accumulator = NumericTypeID::kF32;
host_one = &one;
host_zero = &zero;
}

GemmKey key{NumericTypeID::kF16,
layoutA,
NumericTypeID::kF16,
layoutB,
NumericTypeID::kF16,
LayoutTypeID::kRowMajor,
element_accumulator,
m_algo_param.threadblock_m,
m_algo_param.threadblock_n,
m_algo_param.threadblock_k,
m_algo_param.warp_m,
m_algo_param.warp_n,
m_algo_param.warp_k,
m_algo_param.instruction_m,
m_algo_param.instruction_n,
m_algo_param.instruction_k,
2,
alignment,
alignment,
SplitKMode::kNone};

const auto& table = Singleton::get().operation_table;
megdnn_assert(table.gemm_operations.count(key) > 0,
"key not found in cutlass operation table");
const auto& ops = table.gemm_operations.at(key);
megdnn_assert(ops.size() == 1, "exactly one kernel expected, got %zu",
ops.size());

GemmArguments gemm_args{problem_size,
args.tensor_a.raw_ptr,
args.tensor_b.raw_ptr,
args.tensor_c.raw_ptr,
args.tensor_c.raw_ptr,
lda,
ldb,
ldc,
ldc,
1,
host_one,
host_zero};

cutlass_check(ops[0]->run(&gemm_args, workspace, stream));
}
#endif

// vim: syntax=cpp.doxygen

dnn/src/cuda/matrix_mul/cutlass_float16_tensorop_split_k.cpp (+165, -0)

@@ -0,0 +1,165 @@
/**
* \file dnn/src/cuda/matrix_mul/cutlass_float16_tensorop_split_k.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#include "src/cuda/cutlass/singleton.h"
#include "src/cuda/handle.h"
#include "src/cuda/matrix_mul/algos.h"
#include "src/cuda/utils.h"

#if CUDA_VERSION >= 9020
using namespace megdnn;
using namespace cuda;

bool MatrixMulForwardImpl::AlgoFloat16TensorOpSplitK::is_available(
const SizeArgs& args) const {
auto&& param = args.opr->param();
int n = args.layout_c.shape[1],
k = args.layout_a.shape[param.transposeA ? 0 : 1];
bool available =
args.opr->param().format == param::MatrixMul::Format::DEFAULT &&
args.layout_a.dtype == dtype::Float16() &&
args.layout_b.dtype == dtype::Float16() &&
args.layout_c.dtype == dtype::Float16() && k > n;
auto&& device_prop = cuda::current_device_prop();
int y_grid_limit = device_prop.maxGridSize[1];
// limit y grid
available &= ((n + m_algo_param.threadblock_n - 1) /
m_algo_param.threadblock_n <=
y_grid_limit);
if (m_algo_param.instruction_m == 8 && m_algo_param.instruction_n == 8 &&
m_algo_param.instruction_k == 4) {
available &= is_compute_capability_required(7, 0);
} else {
megdnn_assert(m_algo_param.instruction_m == 16 &&
m_algo_param.instruction_n == 8 &&
m_algo_param.instruction_k == 8);
available &= is_compute_capability_required(7, 5);
}

return available;
}

size_t MatrixMulForwardImpl::AlgoFloat16TensorOpSplitK::get_workspace_in_bytes(
const SizeArgs& args) const {
auto aligned = construct_aligned_layouts(args);
auto&& param = args.opr->param();
int m = args.layout_c.shape[0], n = args.layout_c.shape[1],
k = args.layout_a.shape[param.transposeA ? 0 : 1];
int split_k_slices = std::max(1, k / n);
if (!aligned.first)
return args.layout_c.dtype.size(m * n * split_k_slices);
const auto& layouts = aligned.second;
int align_m = layouts[2].shape[0], align_n = layouts[2].shape[1],
align_k = layouts[0].shape[1];
split_k_slices = std::max(1, align_k / align_n);
size_t ws_size =
args.layout_c.dtype.size(align_m * align_n * split_k_slices);
for (auto&& ly : layouts)
ws_size += ly.span().dist_byte();
return ws_size;
}
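A worked example of the workspace formula above, as a standalone sketch under the assumption that the operands are already suitably aligned, so no extra relayout buffers are appended:

    #include <algorithm>
    #include <cstdio>

    int main() {
        // Hypothetical fp16 problem; the split-k algorithm is only considered when k > n.
        int m = 128, n = 256, k = 4096;
        int split_k_slices = std::max(1, k / n);             // 16 partial-sum slices
        long long elem_bytes = 2;                            // sizeof(dt_float16)
        long long workspace = elem_bytes * m * n * split_k_slices;
        std::printf("split_k_slices = %d, workspace = %lld bytes\n",
                    split_k_slices, workspace);              // 16, 1048576 bytes
        return 0;
    }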

void MatrixMulForwardImpl::AlgoFloat16TensorOpSplitK::do_exec(
const ExecArgs& args) const {
int64_t lda = args.tensor_a.layout.stride[0],
ldb = args.tensor_b.layout.stride[0],
ldc = args.tensor_c.layout.stride[0];
int alignment = max_alignment(args);
int min_alignment = min_alignment_requirement();
auto&& param = args.opr->param();
int m = args.tensor_c.layout.shape[0], n = args.tensor_c.layout.shape[1],
k = args.tensor_a.layout.shape[param.transposeA ? 0 : 1];
megdnn_assert(lda % alignment == 0 && ldb % alignment == 0 &&
ldc % alignment == 0 && m % alignment == 0 &&
n % alignment == 0 && k % alignment == 0 &&
alignment >= min_alignment);
cutlass::gemm::GemmCoord problem_size{m, n, k};
int split_k_slices = std::max(1, k / n);
auto&& stream = cuda_stream(args.opr->handle());
int* workspace = reinterpret_cast<int*>(args.workspace.raw_ptr);
// \note the epilogue constants (i.e. one and zero) are passed to cutlass by
// pointer and reinterpreted as ElementCompute*, which is then used to
// initialize kernel parameters. So the arguments' type on the host side
// must match the ElementCompute type of the kernel instance; otherwise,
// undefined kernel behavior will occur due to incorrect interpretation of
// these pointers.
float one = 1.f, zero = 0.f;
dt_float16 one_f16 = static_cast<dt_float16>(one),
zero_f16 = static_cast<dt_float16>(zero);

using namespace cutlass::library;

auto layoutA = param.transposeA ? LayoutTypeID::kColumnMajor
: LayoutTypeID::kRowMajor;
auto layoutB = param.transposeB ? LayoutTypeID::kColumnMajor
: LayoutTypeID::kRowMajor;

void *host_one, *host_zero;
NumericTypeID element_accumulator;
if (param.compute_mode == param::MatrixMul::ComputeMode::DEFAULT) {
element_accumulator = NumericTypeID::kF16;
host_one = &one_f16;
host_zero = &zero_f16;
} else {
megdnn_assert(param.compute_mode ==
param::MatrixMul::ComputeMode::FLOAT32);
element_accumulator = NumericTypeID::kF32;
host_one = &one;
host_zero = &zero;
}

GemmKey key{NumericTypeID::kF16,
layoutA,
NumericTypeID::kF16,
layoutB,
NumericTypeID::kF16,
LayoutTypeID::kRowMajor,
element_accumulator,
m_algo_param.threadblock_m,
m_algo_param.threadblock_n,
m_algo_param.threadblock_k,
m_algo_param.warp_m,
m_algo_param.warp_n,
m_algo_param.warp_k,
m_algo_param.instruction_m,
m_algo_param.instruction_n,
m_algo_param.instruction_k,
2,
alignment,
alignment,
SplitKMode::kParallel};

const auto& table = Singleton::get().operation_table;
megdnn_assert(table.gemm_operations.count(key) > 0,
"key not found in cutlass operation table");
const auto& ops = table.gemm_operations.at(key);
megdnn_assert(ops.size() == 1, "exactly one kernel expected, got %zu",
ops.size());

GemmArguments gemm_args{problem_size,
args.tensor_a.raw_ptr,
args.tensor_b.raw_ptr,
args.tensor_c.raw_ptr,
args.tensor_c.raw_ptr,
lda,
ldb,
ldc,
ldc,
split_k_slices,
host_one,
host_zero};

cutlass_check(ops[0]->run(&gemm_args, workspace, stream));
}
#endif

// vim: syntax=cpp.doxygen
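A minimal standalone sketch of how get_workspace_in_bytes above sizes the parallel split-k workspace for an already-aligned fp16 problem; the m/n/k values below are hypothetical, and sizeof(short) stands in for dt_float16:

#include <algorithm>
#include <cstddef>
#include <cstdio>

int main() {
    int m = 128, n = 128, k = 9216;               // hypothetical problem size
    int split_k_slices = std::max(1, k / n);      // 72 slices along k
    // one m x n fp16 partial accumulator per slice, reduced afterwards
    std::size_t ws = sizeof(short) * std::size_t(m) * n * split_k_slices;
    std::printf("slices=%d workspace=%zu bytes\n", split_k_slices, ws);  // 2359296
    return 0;
}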

+ 8
- 3
dnn/src/cuda/matrix_mul/cutlass_float32_simt.cpp View File

@@ -42,7 +42,8 @@ size_t MatrixMulForwardImpl::AlgoFloat32SIMT::get_workspace_in_bytes(
return 0_z; return 0_z;
} }


void MatrixMulForwardImpl::AlgoFloat32SIMT::exec(const ExecArgs& args) const {
void MatrixMulForwardImpl::AlgoFloat32SIMT::do_exec(
const ExecArgs& args) const {
int64_t lda = args.tensor_a.layout.stride[0], int64_t lda = args.tensor_a.layout.stride[0],
ldb = args.tensor_b.layout.stride[0], ldb = args.tensor_b.layout.stride[0],
ldc = args.tensor_c.layout.stride[0]; ldc = args.tensor_c.layout.stride[0];
@@ -65,12 +66,14 @@ void MatrixMulForwardImpl::AlgoFloat32SIMT::exec(const ExecArgs& args) const {
auto layoutB = param.transposeB ? LayoutTypeID::kColumnMajor auto layoutB = param.transposeB ? LayoutTypeID::kColumnMajor
: LayoutTypeID::kRowMajor; : LayoutTypeID::kRowMajor;


int alignment = min_alignment_requirement();
GemmKey key{NumericTypeID::kF32, GemmKey key{NumericTypeID::kF32,
layoutA, layoutA,
NumericTypeID::kF32, NumericTypeID::kF32,
layoutB, layoutB,
NumericTypeID::kF32, NumericTypeID::kF32,
LayoutTypeID::kRowMajor, LayoutTypeID::kRowMajor,
NumericTypeID::kF32,
m_algo_param.threadblock_m, m_algo_param.threadblock_m,
m_algo_param.threadblock_n, m_algo_param.threadblock_n,
m_algo_param.threadblock_k, m_algo_param.threadblock_k,
@@ -79,8 +82,10 @@ void MatrixMulForwardImpl::AlgoFloat32SIMT::exec(const ExecArgs& args) const {
m_algo_param.warp_k, m_algo_param.warp_k,
1, 1,
1, 1,
1,
2,
1,
2,
alignment,
alignment,
SplitKMode::kNone}; SplitKMode::kNone};


const Operation* op = Singleton::get().operation_table.find_op(key); const Operation* op = Singleton::get().operation_table.find_op(key);


+ 9
- 5
dnn/src/cuda/matrix_mul/cutlass_float32_simt_split_k.cpp View File

@@ -22,7 +22,7 @@ using namespace cuda;
bool MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::is_available( bool MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::is_available(
const SizeArgs& args) const { const SizeArgs& args) const {
auto&& param = args.opr->param(); auto&& param = args.opr->param();
int m = args.layout_c.shape[0], n = args.layout_c.shape[1],
int n = args.layout_c.shape[1],
k = args.layout_a.shape[param.transposeA ? 0 : 1]; k = args.layout_a.shape[param.transposeA ? 0 : 1];
bool available = bool available =
args.opr->param().format == param::MatrixMul::Format::DEFAULT && args.opr->param().format == param::MatrixMul::Format::DEFAULT &&
@@ -32,8 +32,8 @@ bool MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::is_available(
auto&& device_prop = cuda::current_device_prop(); auto&& device_prop = cuda::current_device_prop();
int y_grid_limit = device_prop.maxGridSize[1]; int y_grid_limit = device_prop.maxGridSize[1];
// limit y grid // limit y grid
available &= ((m + m_algo_param.threadblock_m - 1) /
m_algo_param.threadblock_m <=
available &= ((n + m_algo_param.threadblock_n - 1) /
m_algo_param.threadblock_n <=
y_grid_limit); y_grid_limit);
return available; return available;
} }
@@ -47,7 +47,7 @@ size_t MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::get_workspace_in_bytes(
return args.layout_c.dtype.size(m * n * split_k_slices); return args.layout_c.dtype.size(m * n * split_k_slices);
} }


void MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::exec(
void MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::do_exec(
const ExecArgs& args) const { const ExecArgs& args) const {
int64_t lda = args.tensor_a.layout.stride[0], int64_t lda = args.tensor_a.layout.stride[0],
ldb = args.tensor_b.layout.stride[0], ldb = args.tensor_b.layout.stride[0],
@@ -72,12 +72,14 @@ void MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::exec(
auto layoutB = param.transposeB ? LayoutTypeID::kColumnMajor auto layoutB = param.transposeB ? LayoutTypeID::kColumnMajor
: LayoutTypeID::kRowMajor; : LayoutTypeID::kRowMajor;


int alignment = min_alignment_requirement();
GemmKey key{NumericTypeID::kF32, GemmKey key{NumericTypeID::kF32,
layoutA, layoutA,
NumericTypeID::kF32, NumericTypeID::kF32,
layoutB, layoutB,
NumericTypeID::kF32, NumericTypeID::kF32,
LayoutTypeID::kRowMajor, LayoutTypeID::kRowMajor,
NumericTypeID::kF32,
m_algo_param.threadblock_m, m_algo_param.threadblock_m,
m_algo_param.threadblock_n, m_algo_param.threadblock_n,
m_algo_param.threadblock_k, m_algo_param.threadblock_k,
@@ -87,7 +89,9 @@ void MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::exec(
1, 1,
1, 1,
1, 1,
2,
2,
alignment,
alignment,
SplitKMode::kParallel}; SplitKMode::kParallel};


Operation const* op = Singleton::get().operation_table.find_op(key); Operation const* op = Singleton::get().operation_table.find_op(key);


+ 136
- 0
dnn/src/cuda/matrix_mul/cutlass_matrix_mul_base.cpp View File

@@ -0,0 +1,136 @@
/**
* \file dnn/src/cuda/matrix_mul/cutlass_matrix_mul_base.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#include "src/cuda/handle.h"
#include "src/cuda/matrix_mul/algos.h"
#include "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuh"
#include "src/cuda/utils.h"

#if CUDA_VERSION >= 9020
using namespace megdnn;
using namespace cuda;

std::string
MatrixMulForwardImpl::AlgoCutlassMatrixMulBase::AlgoParam::to_string() const {
return ssprintf("%dX%dX%d_%dX%dX%d", threadblock_m, threadblock_n,
threadblock_k, warp_m, warp_n, warp_k);
}

std::pair<bool, TensorLayoutArray>
MatrixMulForwardImpl::AlgoCutlassMatrixMulBase::construct_aligned_layouts(
const SizeArgs& args) const {
int alignment = max_alignment(args);
int min_alignment = min_alignment_requirement();
bool aligned = alignment >= min_alignment;
if (aligned)
return std::make_pair(!aligned, TensorLayoutArray{{}});
auto&& param = args.opr->param();
int m = args.layout_c.shape[0], n = args.layout_c.shape[1],
k = args.layout_a.shape[param.transposeA ? 0 : 1];
size_t align_m = get_aligned_power2(m, min_alignment);
size_t align_n = get_aligned_power2(n, min_alignment);
size_t align_k = get_aligned_power2(k, min_alignment);
TensorLayoutArray layouts;
layouts.emplace_back(TensorLayout{{align_m, align_k}, args.layout_a.dtype});
layouts.emplace_back(TensorLayout{{align_k, align_n}, args.layout_b.dtype});
layouts.emplace_back(TensorLayout{{align_m, align_n}, args.layout_c.dtype});
return std::make_pair(!aligned, std::move(layouts));
}

void MatrixMulForwardImpl::AlgoCutlassMatrixMulBase::exec(
const ExecArgs& args) const {
auto aligned = construct_aligned_layouts(args);
if (!aligned.first)
return do_exec(args);
const auto& layouts = aligned.second;
auto tensor_a = args.tensor_a;
auto tensor_b = args.tensor_b;
auto workspace = args.workspace;
size_t copy_size = 0;
for (const auto& ly : layouts)
copy_size += ly.span().dist_byte();
auto&& param = args.opr->param();
auto&& stream = cuda_stream(args.opr->handle());

cuda_check(cudaMemsetAsync(workspace.raw_ptr, 0, copy_size, stream));

auto&& relayout = args.opr->handle()->create_operator<RelayoutForward>();

auto copy_stride = [](const TensorLayout& src, TensorLayout& dst,
bool trans) {
dst.stride[0] = src.stride[0], dst.stride[1] = src.stride[1];
if (trans)
std::swap(dst.stride[0], dst.stride[1]);
};
copy_stride(layouts[0], tensor_a.layout, param.transposeA);
tensor_a.raw_ptr = workspace.raw_ptr;
relayout->exec(args.tensor_a, tensor_a);
workspace.raw_ptr += layouts[0].span().dist_byte();
workspace.size -= layouts[0].span().dist_byte();

copy_stride(layouts[1], tensor_b.layout, param.transposeB);
tensor_b.raw_ptr = workspace.raw_ptr;
relayout->exec(args.tensor_b, tensor_b);
workspace.raw_ptr += layouts[1].span().dist_byte();
workspace.size -= layouts[1].span().dist_byte();

decltype(tensor_a) tensor_c{workspace.raw_ptr, layouts[2]};
workspace.raw_ptr += layouts[2].span().dist_byte();
workspace.size -= layouts[2].span().dist_byte();

auto&& matmul = args.opr->handle()->create_operator<MatrixMulForward>();
matmul->param().transposeA = false;
matmul->param().transposeB = false;
matmul->param().compute_mode = args.opr->param().compute_mode;

tensor_a.layout = layouts[0];
tensor_b.layout = layouts[1];
ExecArgs args_{static_cast<MatrixMulForwardImpl*>(matmul.get()), tensor_a,
tensor_b, tensor_c, workspace};
do_exec(args_);

tensor_c.layout.TensorShape::operator=(args.layout_c);
relayout->exec(tensor_c, args.tensor_c);
}

int MatrixMulForwardImpl::AlgoCutlassMatrixMulBase::max_alignment(
const SizeArgs& args) const {
auto&& dtype_a = args.layout_a.dtype;
auto&& dtype_b = args.layout_b.dtype;
auto&& dtype_c = args.layout_c.dtype;
auto get_alignment = [](const DType& dt, int len) {
int size_bits = dt.size(1) * 8;
int align = 128;
while (align > 1) {
if ((len * size_bits) % align == 0)
break;
align = align / 2;
}
return align / size_bits;
};
int lda = args.layout_a.stride[0], ldb = args.layout_b.stride[0],
ldc = args.layout_c.stride[0];
auto&& param = args.opr->param();
int m = args.layout_c.shape[0], n = args.layout_c.shape[1],
k = args.layout_a.shape[param.transposeA ? 0 : 1];
int max_align = get_alignment(dtype_a, lda);
max_align = std::min(get_alignment(dtype_a, m), max_align);
max_align = std::min(get_alignment(dtype_a, n), max_align);
max_align = std::min(get_alignment(dtype_a, k), max_align);
max_align = std::min(get_alignment(dtype_a, lda), max_align);
max_align = std::min(get_alignment(dtype_b, ldb), max_align);
max_align = std::min(get_alignment(dtype_c, ldc), max_align);
return max_align;
}
#endif

// vim: syntax=cpp.doxygen
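A self-contained sketch of the alignment rule used by max_alignment above, mirroring the get_alignment lambda with hypothetical leading dimensions: the result is the largest power-of-two access width up to 128 bits that evenly divides the given extent, converted back to elements.

#include <cstdio>

// mirrors the get_alignment lambda: elem_bits is the dtype size in bits
static int alignment_for(int elem_bits, int len) {
    int align = 128;                              // start from 128-bit vector access
    while (align > 1 && (len * elem_bits) % align != 0)
        align /= 2;
    return align / elem_bits;                     // convert bits back to elements
}

int main() {
    std::printf("%d\n", alignment_for(16, 9216));  // fp16, ld = 9216 -> 8 elements
    std::printf("%d\n", alignment_for(16, 100));   // fp16, ld = 100  -> 4 elements
    return 0;
}

max_alignment takes the minimum of this value over m, n, k and the three leading dimensions, and do_exec only runs when that minimum meets min_alignment_requirement(); otherwise construct_aligned_layouts pads the operands and the base class relayouts through workspace.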

+ 3
- 0
dnn/src/cuda/matrix_mul/opr_impl.h View File

@@ -42,9 +42,12 @@ public:
class AlgoBFloat16; class AlgoBFloat16;
#endif #endif
#if CUDA_VERSION >= 9020 #if CUDA_VERSION >= 9020
class AlgoCutlassMatrixMulBase;
class AlgoFloat32SIMT; class AlgoFloat32SIMT;
class AlgoFloat32SIMTSplitK; class AlgoFloat32SIMTSplitK;
class AlgoFloat32SIMTGemvBatchedStrided; class AlgoFloat32SIMTGemvBatchedStrided;
class AlgoFloat16TensorOp;
class AlgoFloat16TensorOpSplitK;
#endif #endif
class AlgoPack; class AlgoPack;




+ 3
- 1
dnn/test/common/matrix_mul.cpp View File

@@ -184,7 +184,8 @@ void matrix_mul::check_matrix_mul(DType A_dtype, DType B_dtype, DType C_dtype,
const ExecutionPolicyAlgoName& algo, const ExecutionPolicyAlgoName& algo,
param::MatrixMul::Format format, size_t nbase, param::MatrixMul::Format format, size_t nbase,
float eps, std::vector<TestArg>&& user_args, float eps, std::vector<TestArg>&& user_args,
bool force_deduce_dst) {
bool force_deduce_dst,
param::MatrixMul::ComputeMode compute_mode) {
megdnn_assert(A_dtype.enumv() == B_dtype.enumv()); megdnn_assert(A_dtype.enumv() == B_dtype.enumv());
Checker<Opr> checker(handle); Checker<Opr> checker(handle);
checker.set_force_deduce_dst(force_deduce_dst); checker.set_force_deduce_dst(force_deduce_dst);
@@ -261,6 +262,7 @@ void matrix_mul::check_matrix_mul(DType A_dtype, DType B_dtype, DType C_dtype,
Param param; Param param;
param.transposeA = arg.mask & 0x1; param.transposeA = arg.mask & 0x1;
param.transposeB = arg.mask & 0x2; param.transposeB = arg.mask & 0x2;
param.compute_mode = compute_mode;
param.format = format; param.format = format;
checker.set_dtype(0, A_dtype) checker.set_dtype(0, A_dtype)
.set_dtype(1, B_dtype) .set_dtype(1, B_dtype)


+ 3
- 1
dnn/test/common/matrix_mul.h View File

@@ -69,7 +69,9 @@ void check_matrix_mul(
const ExecutionPolicyAlgoName& algo = {"", {}}, const ExecutionPolicyAlgoName& algo = {"", {}},
param::MatrixMul::Format format = param::MatrixMul::Format::DEFAULT, param::MatrixMul::Format format = param::MatrixMul::Format::DEFAULT,
size_t nbase = 8, float eps = 1e-3, std::vector<TestArg>&& args = {}, size_t nbase = 8, float eps = 1e-3, std::vector<TestArg>&& args = {},
bool force_deduce_dst = true);
bool force_deduce_dst = true,
param::MatrixMul::ComputeMode compute_mode =
param::MatrixMul::ComputeMode::DEFAULT);


void check_matrix_mul( void check_matrix_mul(
DType A_dtype, DType B_dtype, DType C_dtype, Handle* handle, DType A_dtype, DType B_dtype, DType C_dtype, Handle* handle,


+ 91
- 0
dnn/test/cuda/cutlass_matmul.cpp View File

@@ -21,6 +21,7 @@
#include "test/cuda/fixture.h" #include "test/cuda/fixture.h"
#include "test/cuda/utils.h" #include "test/cuda/utils.h"


#define MEGDNN_WITH_BENCHMARK 1
#if CUDA_VERSION >= 9020 #if CUDA_VERSION >= 9020
namespace megdnn { namespace megdnn {
namespace test { namespace test {
@@ -215,6 +216,14 @@ std::vector<BenchArgs> get_feat_model_args() {
return args; return args;
} }


std::vector<BenchArgs> get_f16_feat_model_args() {
std::vector<BenchArgs> args;
args.emplace_back(BenchArgs{128, 9216, 9216});
args.emplace_back(BenchArgs{128, 6400, 6400});
args.emplace_back(BenchArgs{128, 5184, 5184});
return args;
}

void benchmark_matrix_mul( void benchmark_matrix_mul(
Handle* handle, const std::vector<BenchArgs>& args, DType A_dtype, Handle* handle, const std::vector<BenchArgs>& args, DType A_dtype,
DType B_dtype, DType C_dtype, const char* algo = nullptr, DType B_dtype, DType C_dtype, const char* algo = nullptr,
@@ -364,6 +373,82 @@ MEGDNN_FOREACH_CUTLASS_KERNEL(cb)
#undef cb #undef cb
#undef MEGDNN_FOREACH_CUTLASS_KERNEL #undef MEGDNN_FOREACH_CUTLASS_KERNEL


#define MEGDNN_FOREACH_CUTLASS_KERNEL(cb) \
cb(1, 256, 128, 32, 64, 64, 32, 8, 8, 4); \
cb(2, 128, 256, 32, 64, 64, 32, 8, 8, 4); \
cb(3, 128, 128, 32, 64, 64, 32, 8, 8, 4);

#define cb(name, tbm, tbn, tbk, wm, wn, wk, im, in, ik) \
TEST_F(CUDA, CUTLASS_F16_884_GEMM_##name) { \
require_compute_capability(7, 0); \
matrix_mul::check_matrix_mul<MatrixMulForward>( \
dtype::Float16(), dtype::Float16(), dtype::Float16(), \
handle_cuda(), \
"CUTLASS_FLOAT16_TENSOR_OP_h" #im #in #ik "_" #tbm "X" #tbn \
"X" #tbk "_" #wm "X" #wn "X" #wk, \
param::MatrixMul::Format::DEFAULT, 8, 1e-2, \
matrix_mul::get_matmul_args()); \
}
MEGDNN_FOREACH_CUTLASS_KERNEL(cb)

#undef cb

#define cb(name, tbm, tbn, tbk, wm, wn, wk, im, in, ik) \
TEST_F(CUDA, CUTLASS_F16_884_GEMM_SPLIT_K_##name) { \
require_compute_capability(7, 0); \
matrix_mul::check_matrix_mul<MatrixMulForward>( \
dtype::Float16(), dtype::Float16(), dtype::Float16(), \
handle_cuda(), \
"CUTLASS_FLOAT16_TENSOR_OP_SPLIT_K_h" #im #in #ik "_" #tbm \
"X" #tbn "X" #tbk "_" #wm "X" #wn "X" #wk, \
param::MatrixMul::Format::DEFAULT, 8, 1e-3, \
matrix_mul::get_matmul_args_split_k(), true, \
param::MatrixMul::ComputeMode::FLOAT32); \
}
MEGDNN_FOREACH_CUTLASS_KERNEL(cb)

#undef cb

#undef MEGDNN_FOREACH_CUTLASS_KERNEL
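For illustration, the first invocation above, cb(1, 256, 128, 32, 64, 64, 32, 8, 8, 4), stringizes the tile parameters into the following algorithm names requested from the checker:

// expanded algorithm names for cb(1, 256, 128, 32, 64, 64, 32, 8, 8, 4)
"CUTLASS_FLOAT16_TENSOR_OP_h884_256X128X32_64X64X32"          // CUTLASS_F16_884_GEMM_1
"CUTLASS_FLOAT16_TENSOR_OP_SPLIT_K_h884_256X128X32_64X64X32"  // CUTLASS_F16_884_GEMM_SPLIT_K_1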

#define MEGDNN_FOREACH_CUTLASS_KERNEL(cb) \
cb(1, 256, 128, 32, 64, 64, 32, 16, 8, 8); \
cb(2, 128, 256, 32, 64, 64, 32, 16, 8, 8); \
cb(3, 128, 128, 32, 64, 64, 32, 16, 8, 8);

#define cb(name, tbm, tbn, tbk, wm, wn, wk, im, in, ik) \
TEST_F(CUDA, CUTLASS_F16_1688_GEMM_##name) { \
require_compute_capability(7, 5); \
matrix_mul::check_matrix_mul<MatrixMulForward>( \
dtype::Float16(), dtype::Float16(), dtype::Float16(), \
handle_cuda(), \
"CUTLASS_FLOAT16_TENSOR_OP_h" #im #in #ik "_" #tbm "X" #tbn \
"X" #tbk "_" #wm "X" #wn "X" #wk, \
param::MatrixMul::Format::DEFAULT, 8, 1e-2, \
matrix_mul::get_matmul_args(), true, \
param::MatrixMul::ComputeMode::FLOAT32); \
}
MEGDNN_FOREACH_CUTLASS_KERNEL(cb)

#undef cb

#define cb(name, tbm, tbn, tbk, wm, wn, wk, im, in, ik) \
TEST_F(CUDA, CUTLASS_F16_1688_GEMM_SPLIT_K_##name) { \
require_compute_capability(7, 5); \
matrix_mul::check_matrix_mul<MatrixMulForward>( \
dtype::Float16(), dtype::Float16(), dtype::Float16(), \
handle_cuda(), \
"CUTLASS_FLOAT16_TENSOR_OP_SPLIT_K_h" #im #in #ik "_" #tbm \
"X" #tbn "X" #tbk "_" #wm "X" #wn "X" #wk, \
param::MatrixMul::Format::DEFAULT, 8, 1e-3, \
matrix_mul::get_matmul_args_split_k()); \
}
MEGDNN_FOREACH_CUTLASS_KERNEL(cb)

#undef cb

#undef MEGDNN_FOREACH_CUTLASS_KERNEL

#if MEGDNN_WITH_BENCHMARK #if MEGDNN_WITH_BENCHMARK
TEST_F(CUDA, BENCHMARK_CUTLASS_MATMUL) { TEST_F(CUDA, BENCHMARK_CUTLASS_MATMUL) {
benchmark_matrix_mul(handle_cuda(), get_square_matmul_args(), benchmark_matrix_mul(handle_cuda(), get_square_matmul_args(),
@@ -376,6 +461,12 @@ TEST_F(CUDA, BENCHMARK_CUTLASS_MATMUL_FEAT) {
dtype::Float32(), dtype::Float32(), dtype::Float32(), dtype::Float32(),
"CUTLASS_FLOAT32_SIMT"); "CUTLASS_FLOAT32_SIMT");
} }

TEST_F(CUDA, BENCHMARK_CUTLASS_F16_MATMUL_FEAT) {
benchmark_matrix_mul(handle_cuda(), get_f16_feat_model_args(),
dtype::Float16(), dtype::Float16(), dtype::Float16(),
"CUTLASS_FLOAT16_TENSOR_OP");
}
#endif #endif
} // namespace test } // namespace test
} // namespace megdnn } // namespace megdnn

