GitOrigin-RevId: 2a70335441
tags/v1.6.0-rc1
@@ -163,7 +163,7 @@ using Convolution = | |||||
${element_bias}, | ${element_bias}, | ||||
${layout_bias}, | ${layout_bias}, | ||||
${element_accumulator}, | ${element_accumulator}, | ||||
${conv_type}, | |||||
${conv_type}, | |||||
${opcode_class}, | ${opcode_class}, | ||||
${arch}, | ${arch}, | ||||
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, | cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, | ||||
@@ -246,6 +246,7 @@ using Deconvolution = | |||||
${element_bias}, | ${element_bias}, | ||||
${layout_bias}, | ${layout_bias}, | ||||
${element_accumulator}, | ${element_accumulator}, | ||||
${conv_type}, | |||||
${opcode_class}, | ${opcode_class}, | ||||
${arch}, | ${arch}, | ||||
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, | cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, | ||||
@@ -276,6 +277,7 @@ using Deconvolution = | |||||
values = { | values = { | ||||
'operation_name': operation.procedural_name(), | 'operation_name': operation.procedural_name(), | ||||
'conv_type': ConvTypeTag[operation.conv_type], | |||||
'element_src': DataTypeTag[operation.src.element], | 'element_src': DataTypeTag[operation.src.element], | ||||
'layout_src': LayoutTag[operation.src.layout], | 'layout_src': LayoutTag[operation.src.layout], | ||||
'element_flt': DataTypeTag[operation.flt.element], | 'element_flt': DataTypeTag[operation.flt.element], | ||||
@@ -530,44 +532,17 @@ void initialize_${configuration_name}(Manifest &manifest) { | |||||
################################################################################################### | ################################################################################################### | ||||
class EmitConvSingleKernelWrapper(): | class EmitConvSingleKernelWrapper(): | ||||
def __init__(self, kernel_path, operation, wrapper_path): | |||||
def __init__(self, kernel_path, operation): | |||||
self.kernel_path = kernel_path | self.kernel_path = kernel_path | ||||
self.wrapper_path = wrapper_path | |||||
self.operation = operation | self.operation = operation | ||||
self.conv_wrappers = { \ | |||||
ConvKind.Fprop: """ | |||||
template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>( | |||||
const typename Convolution::ElementSrc* d_src, | |||||
const typename Convolution::ElementFilter* d_filter, | |||||
const typename Convolution::ElementBias* d_bias, | |||||
const typename Convolution::ElementDst* d_z, | |||||
typename Convolution::ElementDst* d_dst, | |||||
int* workspace, | |||||
typename Convolution::ConvolutionParameter const& conv_param, | |||||
typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||||
cudaStream_t stream, | |||||
typename Convolution::ExtraParam extra_param); | |||||
""", \ | |||||
ConvKind.Dgrad: """ | |||||
template void megdnn::cuda::cutlass_wrapper::cutlass_deconvolution_wrapper<Deconvolution>( | |||||
const typename Deconvolution::ElementSrc* d_src, | |||||
const typename Deconvolution::ElementFilter* d_filter, | |||||
const typename Deconvolution::ElementBias* d_bias, | |||||
const typename Deconvolution::ElementDst* d_z, | |||||
typename Deconvolution::ElementDst* d_dst, | |||||
int* workspace, | |||||
typename Deconvolution::ConvolutionParameter const& conv_param, | |||||
typename Deconvolution::EpilogueOutputOp::Params const& epilogue, | |||||
cudaStream_t stream); | |||||
""", \ | |||||
} | |||||
if self.operation.conv_kind == ConvKind.Fprop: | if self.operation.conv_kind == ConvKind.Fprop: | ||||
self.instance_emitter = EmitConv2dInstance() | self.instance_emitter = EmitConv2dInstance() | ||||
self.convolution_name = "Convolution" | |||||
else: | else: | ||||
assert self.operation.conv_kind == ConvKind.Dgrad | assert self.operation.conv_kind == ConvKind.Dgrad | ||||
self.instance_emitter = EmitDeconvInstance() | self.instance_emitter = EmitDeconvInstance() | ||||
self.convolution_name = "Deconvolution" | |||||
self.header_template = """ | self.header_template = """ | ||||
#if !MEGDNN_TEGRA_X1 | #if !MEGDNN_TEGRA_X1 | ||||
@@ -575,13 +550,30 @@ template void megdnn::cuda::cutlass_wrapper::cutlass_deconvolution_wrapper<Decon | |||||
#pragma GCC diagnostic push | #pragma GCC diagnostic push | ||||
#pragma GCC diagnostic ignored "-Wunused-parameter" | #pragma GCC diagnostic ignored "-Wunused-parameter" | ||||
#pragma GCC diagnostic ignored "-Wstrict-aliasing" | #pragma GCC diagnostic ignored "-Wstrict-aliasing" | ||||
#include "${wrapper_path}" | |||||
#pragma GCC diagnostic ignored "-Wuninitialized" | |||||
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized" | |||||
#include "cutlass/convolution/device/convolution.h" | |||||
#include "src/cuda/cutlass/manifest.h" | |||||
#include "src/cuda/cutlass/convolution_operation.h" | |||||
""" | """ | ||||
self.instance_template = """ | self.instance_template = """ | ||||
${operation_instance} | ${operation_instance} | ||||
""" | """ | ||||
self.wrapper_template = """ | |||||
${wrapper_instance} | |||||
self.manifest_template = """ | |||||
namespace cutlass { | |||||
namespace library { | |||||
void initialize_${operation_name}(Manifest &manifest) { | |||||
manifest.append(new ConvolutionOperation<${convolution_name}>( | |||||
"${operation_name}" | |||||
)); | |||||
} | |||||
} // namespace library | |||||
} // namespace cutlass | |||||
""" | """ | ||||
self.epilogue_template = """ | self.epilogue_template = """ | ||||
@@ -593,9 +585,7 @@ ${wrapper_instance} | |||||
def __enter__(self): | def __enter__(self): | ||||
self.kernel_path = os.path.join(self.kernel_path, "%s.cu" % self.operation.procedural_name()) | self.kernel_path = os.path.join(self.kernel_path, "%s.cu" % self.operation.procedural_name()) | ||||
self.kernel_file = LazyFile(self.kernel_path) | self.kernel_file = LazyFile(self.kernel_path) | ||||
self.kernel_file.write(SubstituteTemplate(self.header_template, { | |||||
'wrapper_path': self.wrapper_path, | |||||
})) | |||||
self.kernel_file.write(self.header_template) | |||||
return self | return self | ||||
# | # | ||||
@@ -604,11 +594,12 @@ ${wrapper_instance} | |||||
'operation_instance': self.instance_emitter.emit(self.operation), | 'operation_instance': self.instance_emitter.emit(self.operation), | ||||
})) | })) | ||||
# emit wrapper | |||||
wrapper = SubstituteTemplate(self.wrapper_template, { | |||||
'wrapper_instance': self.conv_wrappers[self.operation.conv_kind], | |||||
# emit manifest helper | |||||
manifest = SubstituteTemplate(self.manifest_template, { | |||||
'operation_name': self.operation.procedural_name(), | |||||
'convolution_name': self.convolution_name | |||||
}) | }) | ||||
self.kernel_file.write(wrapper) | |||||
self.kernel_file.write(manifest) | |||||
# | # | ||||
def __exit__(self, exception_type, exception_value, traceback): | def __exit__(self, exception_type, exception_value, traceback): | ||||
@@ -940,8 +940,8 @@ void initialize_${configuration_name}(Manifest &manifest) { | |||||
/////////////////////////////////////////////////////////////////////////////////////////////////// | /////////////////////////////////////////////////////////////////////////////////////////////////// | ||||
} // namespace library | |||||
} // namespace cutlass | |||||
} // namespace library | |||||
} // namespace cutlass | |||||
/////////////////////////////////////////////////////////////////////////////////////////////////// | /////////////////////////////////////////////////////////////////////////////////////////////////// | ||||
@@ -995,48 +995,101 @@ void initialize_${configuration_name}(Manifest &manifest) { | |||||
################################################################################################### | ################################################################################################### | ||||
class EmitGemmSingleKernelWrapper: | class EmitGemmSingleKernelWrapper: | ||||
def __init__(self, kernel_path, gemm_operation, wrapper_path): | |||||
def __init__(self, kernel_path, gemm_operation): | |||||
self.kernel_path = kernel_path | self.kernel_path = kernel_path | ||||
self.wrapper_path = wrapper_path | |||||
self.operation = gemm_operation | self.operation = gemm_operation | ||||
gemm_wrapper = """ | |||||
template void megdnn::cuda::cutlass_wrapper::cutlass_matrix_mul_wrapper<Operation_${operation_name}>( | |||||
const typename Operation_${operation_name}::ElementA* d_A, size_t lda, | |||||
const typename Operation_${operation_name}::ElementB* d_B, size_t ldb, | |||||
typename Operation_${operation_name}::ElementC* d_C, size_t ldc, | |||||
int* workspace, | |||||
cutlass::gemm::GemmCoord const& problem_size, | |||||
typename Operation_${operation_name}::EpilogueOutputOp::Params const& epilogue, | |||||
cudaStream_t stream, int split_k_slices); | |||||
instance_emitters = { | |||||
GemmKind.Gemm: EmitGemmInstance(), | |||||
GemmKind.SplitKParallel: EmitGemmSplitKParallelInstance(), | |||||
} | |||||
self.instance_emitter = instance_emitters[self.operation.gemm_kind] | |||||
self.header_template = """ | |||||
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2) | |||||
// ignore warning of cutlass | |||||
#pragma GCC diagnostic push | |||||
#pragma GCC diagnostic ignored "-Wunused-parameter" | |||||
#pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||||
#pragma GCC diagnostic ignored "-Wuninitialized" | |||||
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized" | |||||
#include "cutlass/gemm/device/gemm.h" | |||||
#include "cutlass/gemm/device/gemm_splitk_parallel.h" | |||||
#include "src/cuda/cutlass/manifest.h" | |||||
#include "src/cuda/cutlass/gemm_operation.h" | |||||
""" | """ | ||||
self.instance_template = """ | |||||
${operation_instance} | |||||
""" | |||||
self.manifest_template = """ | |||||
namespace cutlass { | |||||
namespace library { | |||||
void initialize_${operation_name}(Manifest &manifest) { | |||||
manifest.append(new GemmOperation< | |||||
Operation_${operation_name} | |||||
>("${operation_name}")); | |||||
} | |||||
} // namespace library | |||||
} // namespace cutlass | |||||
""" | |||||
self.epilogue_template = """ | |||||
#pragma GCC diagnostic pop | |||||
#endif | |||||
""" | |||||
# | |||||
def __enter__(self): | |||||
self.kernel_path = os.path.join(self.kernel_path, "%s.cu" % self.operation.procedural_name()) | |||||
self.kernel_file = LazyFile(self.kernel_path) | |||||
self.kernel_file.write(self.header_template) | |||||
return self | |||||
# | |||||
def emit(self): | |||||
self.kernel_file.write(SubstituteTemplate(self.instance_template, { | |||||
'operation_instance': self.instance_emitter.emit(self.operation), | |||||
})) | |||||
gemv_wrapper = """ | |||||
# emit manifest helper | |||||
manifest = SubstituteTemplate(self.manifest_template, { | |||||
'operation_name': self.operation.procedural_name(), | |||||
}) | |||||
self.kernel_file.write(manifest) | |||||
# | |||||
def __exit__(self, exception_type, exception_value, traceback): | |||||
self.kernel_file.write(self.epilogue_template) | |||||
self.kernel_file.close() | |||||
################################################################################################### | |||||
################################################################################################### | |||||
class EmitGemvSingleKernelWrapper: | |||||
def __init__(self, kernel_path, gemm_operation, wrapper_path): | |||||
self.kernel_path = kernel_path | |||||
self.wrapper_path = wrapper_path | |||||
self.operation = gemm_operation | |||||
self.wrapper_template = """ | |||||
template void megdnn::cuda::cutlass_wrapper:: | template void megdnn::cuda::cutlass_wrapper:: | ||||
cutlass_vector_matrix_mul_batched_strided_wrapper<Operation_${operation_name}>( | cutlass_vector_matrix_mul_batched_strided_wrapper<Operation_${operation_name}>( | ||||
BatchedGemmCoord const& problem_size, | BatchedGemmCoord const& problem_size, | ||||
const typename Operation_${operation_name}::ElementA* d_A, size_t lda, size_t batch_stride_a, | |||||
const typename Operation_${operation_name}::ElementB* d_B, size_t ldb, size_t batch_stride_b, | |||||
const typename Operation_${operation_name}::ElementA* d_A, size_t lda, size_t batch_stride_a, | |||||
const typename Operation_${operation_name}::ElementB* d_B, size_t ldb, size_t batch_stride_b, | |||||
typename Operation_${operation_name}::ElementCD* d_C, size_t ldc, size_t batch_stride_c, | typename Operation_${operation_name}::ElementCD* d_C, size_t ldc, size_t batch_stride_c, | ||||
cudaStream_t stream); | cudaStream_t stream); | ||||
""" | """ | ||||
if self.operation.gemm_kind == GemmKind.SplitKParallel or \ | |||||
self.operation.gemm_kind == GemmKind.Gemm: | |||||
self.wrapper_template = gemm_wrapper | |||||
else: | |||||
assert self.operation.gemm_kind == GemmKind.GemvBatchedStrided | |||||
self.wrapper_template = gemv_wrapper | |||||
instance_emitters = { | |||||
GemmKind.Gemm: EmitGemmInstance(), | |||||
GemmKind.SplitKParallel: EmitGemmSplitKParallelInstance(), | |||||
GemmKind.GemvBatchedStrided: EmitGemvBatchedStridedInstance(), | |||||
} | |||||
self.instance_emitter = instance_emitters[self.operation.gemm_kind] | |||||
self.instance_emitter = EmitGemvBatchedStridedInstance() | |||||
self.header_template = """ | self.header_template = """ | ||||
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2) | |||||
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2) | |||||
// ignore warning of cutlass | // ignore warning of cutlass | ||||
#pragma GCC diagnostic push | #pragma GCC diagnostic push | ||||
#pragma GCC diagnostic ignored "-Wunused-parameter" | #pragma GCC diagnostic ignored "-Wunused-parameter" | ||||
@@ -1055,10 +1108,10 @@ ${operation_instance} | |||||
""" | """ | ||||
# | # | ||||
def __enter__(self): | def __enter__(self): | ||||
self.kernel_path = os.path.join(self.kernel_path, "%s.cu" % self.operation.procedural_name()) | |||||
self.kernel_path = os.path.join(self.kernel_path, "%s.cu" % self.operation.procedural_name()) | |||||
self.kernel_file = LazyFile(self.kernel_path) | self.kernel_file = LazyFile(self.kernel_path) | ||||
self.kernel_file.write(SubstituteTemplate(self.header_template, { | self.kernel_file.write(SubstituteTemplate(self.header_template, { | ||||
'wrapper_path': self.wrapper_path, | |||||
'wrapper_path': self.wrapper_path, | |||||
})) | })) | ||||
return self | return self | ||||
@@ -1070,7 +1123,7 @@ ${operation_instance} | |||||
# emit wrapper | # emit wrapper | ||||
wrapper = SubstituteTemplate(self.wrapper_template, { | wrapper = SubstituteTemplate(self.wrapper_template, { | ||||
'operation_name': self.operation.procedural_name(), | |||||
'operation_name': self.operation.procedural_name(), | |||||
}) | }) | ||||
self.kernel_file.write(wrapper) | self.kernel_file.write(wrapper) | ||||
@@ -1079,7 +1132,5 @@ ${operation_instance} | |||||
self.kernel_file.write(self.epilogue_template) | self.kernel_file.write(self.epilogue_template) | ||||
self.kernel_file.close() | self.kernel_file.close() | ||||
################################################################################################### | ################################################################################################### | ||||
################################################################################################### | ################################################################################################### | ||||
@@ -23,6 +23,8 @@ def write_op_list(f, gen_op, gen_type): | |||||
operations = GenerateDeconvOperations(GenArg(gen_op, gen_type)) | operations = GenerateDeconvOperations(GenArg(gen_op, gen_type)) | ||||
for op in operations: | for op in operations: | ||||
f.write(' "%s.cu",\n' % op.procedural_name()) | f.write(' "%s.cu",\n' % op.procedural_name()) | ||||
if gen_op != "gemv": | |||||
f.write(' "all_%s_%s_operations.cu",\n' % (gen_op, gen_type)) | |||||
if __name__ == "__main__": | if __name__ == "__main__": | ||||
@@ -292,7 +292,7 @@ def GenerateConv2d_TensorOp_8832(args): | |||||
] | ] | ||||
operations += GenerateConv2d(ConvKind.Fprop, tile_descriptions, layout[0], layout[1], | operations += GenerateConv2d(ConvKind.Fprop, tile_descriptions, layout[0], layout[1], | ||||
dst_layout, dst_type, min_cc, 128, 128, 64, | dst_layout, dst_type, min_cc, 128, 128, 64, | ||||
True, ImplicitGemmMode.GemmTN, True) | |||||
False, ImplicitGemmMode.GemmTN, True) | |||||
layouts_nhwc = [ | layouts_nhwc = [ | ||||
(LayoutType.TensorNHWC, LayoutType.TensorNC8HW8, 32), | (LayoutType.TensorNHWC, LayoutType.TensorNC8HW8, 32), | ||||
@@ -633,16 +633,10 @@ if __name__ == "__main__": | |||||
parser.add_argument("--type", type=str, choices=['simt', 'tensorop8816', 'tensorop8832'], | parser.add_argument("--type", type=str, choices=['simt', 'tensorop8816', 'tensorop8832'], | ||||
default='simt', help="kernel type of CUTLASS kernel generator") | default='simt', help="kernel type of CUTLASS kernel generator") | ||||
operation2wrapper_path = { | |||||
"gemm": "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuinl", \ | |||||
"gemv": "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper_batched_gemv_strided.cuinl", \ | |||||
"conv2d": "src/cuda/conv_bias/implicit_gemm_conv_bias_cutlass_wrapper.cuinl", \ | |||||
"deconv": "src/cuda/convolution/backward_data/implicit_gemm_deconv_cutlass_wrapper.cuinl", \ | |||||
} | |||||
gemv_wrapper_path = "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper_batched_gemv_strided.cuinl" | |||||
args = parser.parse_args() | args = parser.parse_args() | ||||
wrapper_path = operation2wrapper_path[args.operations] | |||||
if args.operations == "gemm": | if args.operations == "gemm": | ||||
operations = GenerateGemmOperations(args) | operations = GenerateGemmOperations(args) | ||||
elif args.operations == "gemv": | elif args.operations == "gemv": | ||||
@@ -652,16 +646,22 @@ if __name__ == "__main__": | |||||
elif args.operations == "deconv": | elif args.operations == "deconv": | ||||
operations = GenerateDeconvOperations(args) | operations = GenerateDeconvOperations(args) | ||||
if args.operations == "conv2d" or args.operations == "deconv": | if args.operations == "conv2d" or args.operations == "deconv": | ||||
for operation in operations: | for operation in operations: | ||||
with EmitConvSingleKernelWrapper(args.output, operation, wrapper_path) as emitter: | |||||
with EmitConvSingleKernelWrapper(args.output, operation) as emitter: | |||||
emitter.emit() | emitter.emit() | ||||
elif args.operations == "gemm" or args.operations == "gemv": | |||||
elif args.operations == "gemm": | |||||
for operation in operations: | for operation in operations: | ||||
with EmitGemmSingleKernelWrapper(args.output, operation, wrapper_path) as emitter: | |||||
with EmitGemmSingleKernelWrapper(args.output, operation) as emitter: | |||||
emitter.emit() | emitter.emit() | ||||
elif args.operations == "gemv": | |||||
for operation in operations: | |||||
with EmitGemvSingleKernelWrapper(args.output, operation, gemv_wrapper_path) as emitter: | |||||
emitter.emit() | |||||
if args.operations != "gemv": | |||||
GenerateManifest(args, operations, args.output) | |||||
# | # | ||||
################################################################################################### | ################################################################################################### | ||||
@@ -137,6 +137,7 @@ cutlass_gen_list = [ | |||||
"cutlass_simt_sgemm_split_k_parallel_256x32_8x2_tt_align1.cu", | "cutlass_simt_sgemm_split_k_parallel_256x32_8x2_tt_align1.cu", | ||||
"cutlass_simt_sgemm_256x64_8x2_tt_align1.cu", | "cutlass_simt_sgemm_256x64_8x2_tt_align1.cu", | ||||
"cutlass_simt_sgemm_split_k_parallel_256x64_8x2_tt_align1.cu", | "cutlass_simt_sgemm_split_k_parallel_256x64_8x2_tt_align1.cu", | ||||
"all_gemm_simt_operations.cu", | |||||
"cutlass_simt_sgemv_batched_strided_1x128_32_tt_align4x4.cu", | "cutlass_simt_sgemv_batched_strided_1x128_32_tt_align4x4.cu", | ||||
"cutlass_simt_sgemv_batched_strided_1x128_16_tt_align4x2.cu", | "cutlass_simt_sgemv_batched_strided_1x128_16_tt_align4x2.cu", | ||||
"cutlass_simt_sgemv_batched_strided_1x128_8_tt_align4x1.cu", | "cutlass_simt_sgemv_batched_strided_1x128_8_tt_align4x1.cu", | ||||
@@ -169,6 +170,7 @@ cutlass_gen_list = [ | |||||
"cutlass_simt_s8_idgrad_id_s8_16x128x16_16x64x16_2_nc4hw4_k4rsc4.cu", | "cutlass_simt_s8_idgrad_id_s8_16x128x16_16x64x16_2_nc4hw4_k4rsc4.cu", | ||||
"cutlass_simt_s8_idgrad_id_s8_16x128x16_16x128x16_1_nc4hw4_k4rsc4.cu", | "cutlass_simt_s8_idgrad_id_s8_16x128x16_16x128x16_1_nc4hw4_k4rsc4.cu", | ||||
"cutlass_simt_s8_idgrad_id_s8_16x64x8_16x64x8_2_nc4hw4_k4rsc4.cu", | "cutlass_simt_s8_idgrad_id_s8_16x64x8_16x64x8_2_nc4hw4_k4rsc4.cu", | ||||
"all_deconv_simt_operations.cu", | |||||
"cutlass_simt_s8_ifprop_id_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4.cu", | "cutlass_simt_s8_ifprop_id_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4.cu", | ||||
"cutlass_simt_s8_ifprop_1x1_id_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4.cu", | "cutlass_simt_s8_ifprop_1x1_id_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4.cu", | ||||
"cutlass_simt_s8_ifprop_relu_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4.cu", | "cutlass_simt_s8_ifprop_relu_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4.cu", | ||||
@@ -373,6 +375,7 @@ cutlass_gen_list = [ | |||||
"cutlass_simt_f32_ifprop_1x1_relu_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4_nchw.cu", | "cutlass_simt_f32_ifprop_1x1_relu_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4_nchw.cu", | ||||
"cutlass_simt_f32_ifprop_hswish_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4_nchw.cu", | "cutlass_simt_f32_ifprop_hswish_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4_nchw.cu", | ||||
"cutlass_simt_f32_ifprop_1x1_hswish_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4_nchw.cu", | "cutlass_simt_f32_ifprop_1x1_hswish_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4_nchw.cu", | ||||
"all_conv2d_simt_operations.cu", | |||||
"cutlass_tensorop_s8_i8816fprop_roc_id_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32.cu", | "cutlass_tensorop_s8_i8816fprop_roc_id_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32.cu", | ||||
"cutlass_tensorop_s8_i8816fprop_1x1_roc_id_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32.cu", | "cutlass_tensorop_s8_i8816fprop_1x1_roc_id_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32.cu", | ||||
"cutlass_tensorop_s8_i8816fprop_roc_relu_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32.cu", | "cutlass_tensorop_s8_i8816fprop_roc_relu_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32.cu", | ||||
@@ -481,26 +484,47 @@ cutlass_gen_list = [ | |||||
"cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_32x128x32_32x64x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | "cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_32x128x32_32x64x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | ||||
"cutlass_tensorop_s8_i8816fprop_hswish_s8_32x128x32_32x64x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | "cutlass_tensorop_s8_i8816fprop_hswish_s8_32x128x32_32x64x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | ||||
"cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_32x128x32_32x64x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | "cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_32x128x32_32x64x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | ||||
"all_conv2d_tensorop8816_operations.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_roc_id_s4_128x256x128_64x64x128_2_nc64hw64_c64rsk64.cu", | "cutlass_tensorop_s4_i8832fprop_roc_id_s4_128x256x128_64x64x128_2_nc64hw64_c64rsk64.cu", | ||||
"cutlass_tensorop_s4_i8832fprop_1x1_roc_id_s4_128x256x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_roc_relu_s4_128x256x128_64x64x128_2_nc64hw64_c64rsk64.cu", | "cutlass_tensorop_s4_i8832fprop_roc_relu_s4_128x256x128_64x64x128_2_nc64hw64_c64rsk64.cu", | ||||
"cutlass_tensorop_s4_i8832fprop_1x1_roc_relu_s4_128x256x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_roc_hswish_s4_128x256x128_64x64x128_2_nc64hw64_c64rsk64.cu", | "cutlass_tensorop_s4_i8832fprop_roc_hswish_s4_128x256x128_64x64x128_2_nc64hw64_c64rsk64.cu", | ||||
"cutlass_tensorop_s4_i8832fprop_1x1_roc_hswish_s4_128x256x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_roc_id_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", | "cutlass_tensorop_s4_i8832fprop_roc_id_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", | ||||
"cutlass_tensorop_s4_i8832fprop_1x1_roc_id_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_roc_relu_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", | "cutlass_tensorop_s4_i8832fprop_roc_relu_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", | ||||
"cutlass_tensorop_s4_i8832fprop_1x1_roc_relu_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_roc_hswish_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", | "cutlass_tensorop_s4_i8832fprop_roc_hswish_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", | ||||
"cutlass_tensorop_s4_i8832fprop_1x1_roc_hswish_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_roc_id_s4_128x64x128_64x64x128_2_nc64hw64_c64rsk64.cu", | "cutlass_tensorop_s4_i8832fprop_roc_id_s4_128x64x128_64x64x128_2_nc64hw64_c64rsk64.cu", | ||||
"cutlass_tensorop_s4_i8832fprop_1x1_roc_id_s4_128x64x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_roc_relu_s4_128x64x128_64x64x128_2_nc64hw64_c64rsk64.cu", | "cutlass_tensorop_s4_i8832fprop_roc_relu_s4_128x64x128_64x64x128_2_nc64hw64_c64rsk64.cu", | ||||
"cutlass_tensorop_s4_i8832fprop_1x1_roc_relu_s4_128x64x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_roc_hswish_s4_128x64x128_64x64x128_2_nc64hw64_c64rsk64.cu", | "cutlass_tensorop_s4_i8832fprop_roc_hswish_s4_128x64x128_64x64x128_2_nc64hw64_c64rsk64.cu", | ||||
"cutlass_tensorop_s4_i8832fprop_1x1_roc_hswish_s4_128x64x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_roc_id_s4_128x64x64_64x64x64_1_nc64hw64_c64rsk64.cu", | "cutlass_tensorop_s4_i8832fprop_roc_id_s4_128x64x64_64x64x64_1_nc64hw64_c64rsk64.cu", | ||||
"cutlass_tensorop_s4_i8832fprop_1x1_roc_id_s4_128x64x64_64x64x64_1_nc64hw64_c64rsk64.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_roc_relu_s4_128x64x64_64x64x64_1_nc64hw64_c64rsk64.cu", | "cutlass_tensorop_s4_i8832fprop_roc_relu_s4_128x64x64_64x64x64_1_nc64hw64_c64rsk64.cu", | ||||
"cutlass_tensorop_s4_i8832fprop_1x1_roc_relu_s4_128x64x64_64x64x64_1_nc64hw64_c64rsk64.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_roc_hswish_s4_128x64x64_64x64x64_1_nc64hw64_c64rsk64.cu", | "cutlass_tensorop_s4_i8832fprop_roc_hswish_s4_128x64x64_64x64x64_1_nc64hw64_c64rsk64.cu", | ||||
"cutlass_tensorop_s4_i8832fprop_1x1_roc_hswish_s4_128x64x64_64x64x64_1_nc64hw64_c64rsk64.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_roc_id_u4_s4_128x256x128_64x64x128_2_nc64hw64_c64rsk64.cu", | "cutlass_tensorop_u4_i8832fprop_roc_id_u4_s4_128x256x128_64x64x128_2_nc64hw64_c64rsk64.cu", | ||||
"cutlass_tensorop_u4_i8832fprop_1x1_roc_id_u4_s4_128x256x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_roc_relu_u4_s4_128x256x128_64x64x128_2_nc64hw64_c64rsk64.cu", | "cutlass_tensorop_u4_i8832fprop_roc_relu_u4_s4_128x256x128_64x64x128_2_nc64hw64_c64rsk64.cu", | ||||
"cutlass_tensorop_u4_i8832fprop_1x1_roc_relu_u4_s4_128x256x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_roc_id_u4_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", | "cutlass_tensorop_u4_i8832fprop_roc_id_u4_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", | ||||
"cutlass_tensorop_u4_i8832fprop_1x1_roc_id_u4_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_roc_relu_u4_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", | "cutlass_tensorop_u4_i8832fprop_roc_relu_u4_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", | ||||
"cutlass_tensorop_u4_i8832fprop_1x1_roc_relu_u4_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_roc_id_u4_s4_128x64x128_64x64x128_2_nc64hw64_c64rsk64.cu", | "cutlass_tensorop_u4_i8832fprop_roc_id_u4_s4_128x64x128_64x64x128_2_nc64hw64_c64rsk64.cu", | ||||
"cutlass_tensorop_u4_i8832fprop_1x1_roc_id_u4_s4_128x64x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_roc_relu_u4_s4_128x64x128_64x64x128_2_nc64hw64_c64rsk64.cu", | "cutlass_tensorop_u4_i8832fprop_roc_relu_u4_s4_128x64x128_64x64x128_2_nc64hw64_c64rsk64.cu", | ||||
"cutlass_tensorop_u4_i8832fprop_1x1_roc_relu_u4_s4_128x64x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_roc_id_u4_s4_128x64x64_64x64x64_1_nc64hw64_c64rsk64.cu", | "cutlass_tensorop_u4_i8832fprop_roc_id_u4_s4_128x64x64_64x64x64_1_nc64hw64_c64rsk64.cu", | ||||
"cutlass_tensorop_u4_i8832fprop_1x1_roc_id_u4_s4_128x64x64_64x64x64_1_nc64hw64_c64rsk64.cu", | |||||
"cutlass_tensorop_u4_i8832fprop_roc_relu_u4_s4_128x64x64_64x64x64_1_nc64hw64_c64rsk64.cu", | "cutlass_tensorop_u4_i8832fprop_roc_relu_u4_s4_128x64x64_64x64x64_1_nc64hw64_c64rsk64.cu", | ||||
"cutlass_tensorop_u4_i8832fprop_1x1_roc_relu_u4_s4_128x64x64_64x64x64_1_nc64hw64_c64rsk64.cu", | |||||
"cutlass_tensorop_s4_i8832fprop_id_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", | "cutlass_tensorop_s4_i8832fprop_id_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", | ||||
"cutlass_tensorop_s4_i8832fprop_1x1_id_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", | "cutlass_tensorop_s4_i8832fprop_1x1_id_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", | ||||
"cutlass_tensorop_s4_i8832fprop_relu_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", | "cutlass_tensorop_s4_i8832fprop_relu_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", | ||||
@@ -621,4 +645,5 @@ cutlass_gen_list = [ | |||||
"cutlass_tensorop_u4_i8832fprop_1x1_roc_id_u4_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", | "cutlass_tensorop_u4_i8832fprop_1x1_roc_id_u4_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", | ||||
"cutlass_tensorop_u4_i8832fprop_roc_relu_u4_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", | "cutlass_tensorop_u4_i8832fprop_roc_relu_u4_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", | ||||
"cutlass_tensorop_u4_i8832fprop_1x1_roc_relu_u4_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", | "cutlass_tensorop_u4_i8832fprop_1x1_roc_relu_u4_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", | ||||
"all_conv2d_tensorop8832_operations.cu", | |||||
] | ] |
@@ -8,6 +8,7 @@ import enum | |||||
import os.path | import os.path | ||||
import shutil | import shutil | ||||
from lazy_file import LazyFile | |||||
from library import * | from library import * | ||||
from gemm_operation import * | from gemm_operation import * | ||||
from conv2d_operation import * | from conv2d_operation import * | ||||
@@ -349,3 +350,41 @@ void initialize_all(Manifest &manifest) { | |||||
# | # | ||||
################################################################################################### | ################################################################################################### | ||||
def GenerateManifest(args, operations, output_dir): | |||||
manifest_path = os.path.join(output_dir, "all_%s_%s_operations.cu" % (args.operations, args.type)) | |||||
f = LazyFile(manifest_path) | |||||
f.write(""" | |||||
/* | |||||
Generated by generator.py - Do not edit. | |||||
*/ | |||||
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2) | |||||
#include "cutlass/cutlass.h" | |||||
#include "src/cuda/cutlass/library.h" | |||||
#include "src/cuda/cutlass/manifest.h" | |||||
namespace cutlass { | |||||
namespace library { | |||||
""") | |||||
for op in operations: | |||||
f.write("void initialize_%s(Manifest &manifest);\n" % op.procedural_name()) | |||||
f.write(""" | |||||
void initialize_all_%s_%s_operations(Manifest &manifest) { | |||||
""" % (args.operations, args.type)) | |||||
for op in operations: | |||||
f.write(" initialize_%s(manifest);\n" % op.procedural_name()) | |||||
f.write(""" | |||||
} | |||||
} // namespace library | |||||
} // namespace cutlass | |||||
#endif | |||||
""") | |||||
f.close() |
@@ -217,68 +217,77 @@ void ConvBiasForwardImpl::AlgoPack::fill_imma_algos() { | |||||
#if CUDA_VERSION >= 10020 | #if CUDA_VERSION >= 10020 | ||||
{ | { | ||||
using AlgoParam = AlgoInt8NCHW32IMMAImplicitGemm::AlgoParam; | using AlgoParam = AlgoInt8NCHW32IMMAImplicitGemm::AlgoParam; | ||||
int8_nchw32_imma.emplace_back(AlgoParam{128, 256, 64, 64, 64, 64, 2}); | |||||
int8_nchw32_imma.emplace_back(AlgoParam{256, 128, 64, 64, 64, 64, 2}); | |||||
int8_nchw32_imma.emplace_back(AlgoParam{128, 128, 64, 64, 64, 64, 2}); | |||||
int8_nchw32_imma.emplace_back(AlgoParam{128, 64, 64, 64, 32, 64, 2}); | |||||
int8_nchw32_imma.emplace_back(AlgoParam{64, 128, 64, 32, 64, 64, 2}); | |||||
int8_nchw32_imma.emplace_back(AlgoParam{128, 64, 32, 64, 32, 32, 1}); | |||||
int8_nchw32_imma.emplace_back(AlgoParam{128, 32, 32, 64, 32, 32, 1}); | |||||
int8_nchw32_imma.emplace_back(AlgoParam{64, 128, 32, 32, 64, 32, 1}); | |||||
int8_nchw32_imma.emplace_back(AlgoParam{32, 128, 32, 32, 64, 32, 1}); | |||||
int8_nchw32_imma.emplace_back( | |||||
AlgoParam{128, 256, 64, 64, 64, 64, 8, 8, 16, 2}); | |||||
int8_nchw32_imma.emplace_back( | |||||
AlgoParam{256, 128, 64, 64, 64, 64, 8, 8, 16, 2}); | |||||
int8_nchw32_imma.emplace_back( | |||||
AlgoParam{128, 128, 64, 64, 64, 64, 8, 8, 16, 2}); | |||||
int8_nchw32_imma.emplace_back( | |||||
AlgoParam{128, 64, 64, 64, 32, 64, 8, 8, 16, 2}); | |||||
int8_nchw32_imma.emplace_back( | |||||
AlgoParam{64, 128, 64, 32, 64, 64, 8, 8, 16, 2}); | |||||
int8_nchw32_imma.emplace_back( | |||||
AlgoParam{128, 64, 32, 64, 32, 32, 8, 8, 16, 1}); | |||||
int8_nchw32_imma.emplace_back( | |||||
AlgoParam{128, 32, 32, 64, 32, 32, 8, 8, 16, 1}); | |||||
int8_nchw32_imma.emplace_back( | |||||
AlgoParam{64, 128, 32, 32, 64, 32, 8, 8, 16, 1}); | |||||
int8_nchw32_imma.emplace_back( | |||||
AlgoParam{32, 128, 32, 32, 64, 32, 8, 8, 16, 1}); | |||||
} | } | ||||
{ | { | ||||
using AlgoParam = AlgoInt4Int4NCHW64IMMAImplicitGemm::AlgoParam; | using AlgoParam = AlgoInt4Int4NCHW64IMMAImplicitGemm::AlgoParam; | ||||
int4_int4_nchw64_imma.emplace_back( | int4_int4_nchw64_imma.emplace_back( | ||||
AlgoParam{128, 128, 128, 64, 64, 128, 2}); | |||||
AlgoParam{128, 128, 128, 64, 64, 128, 8, 8, 32, 2}); | |||||
int4_int4_nchw64_imma.emplace_back( | int4_int4_nchw64_imma.emplace_back( | ||||
AlgoParam{128, 256, 128, 64, 64, 128, 2}); | |||||
AlgoParam{128, 256, 128, 64, 64, 128, 8, 8, 32, 2}); | |||||
int4_int4_nchw64_imma.emplace_back( | int4_int4_nchw64_imma.emplace_back( | ||||
AlgoParam{128, 64, 128, 64, 64, 128, 2}); | |||||
AlgoParam{128, 64, 128, 64, 64, 128, 8, 8, 32, 2}); | |||||
int4_int4_nchw64_imma.emplace_back( | int4_int4_nchw64_imma.emplace_back( | ||||
AlgoParam{128, 64, 64, 64, 64, 64, 1}); | |||||
AlgoParam{128, 64, 64, 64, 64, 64, 8, 8, 32, 1}); | |||||
} | } | ||||
{ | { | ||||
using AlgoParam = AlgoUInt4Int4NCHW64IMMAImplicitGemm::AlgoParam; | using AlgoParam = AlgoUInt4Int4NCHW64IMMAImplicitGemm::AlgoParam; | ||||
uint4_int4_nchw64_imma.emplace_back( | uint4_int4_nchw64_imma.emplace_back( | ||||
AlgoParam{128, 128, 128, 64, 64, 128, 2}); | |||||
AlgoParam{128, 128, 128, 64, 64, 128, 8, 8, 32, 2}); | |||||
uint4_int4_nchw64_imma.emplace_back( | uint4_int4_nchw64_imma.emplace_back( | ||||
AlgoParam{128, 256, 128, 64, 64, 128, 2}); | |||||
AlgoParam{128, 256, 128, 64, 64, 128, 8, 8, 32, 2}); | |||||
uint4_int4_nchw64_imma.emplace_back( | uint4_int4_nchw64_imma.emplace_back( | ||||
AlgoParam{128, 64, 128, 64, 64, 128, 2}); | |||||
AlgoParam{128, 64, 128, 64, 64, 128, 8, 8, 32, 2}); | |||||
uint4_int4_nchw64_imma.emplace_back( | uint4_int4_nchw64_imma.emplace_back( | ||||
AlgoParam{128, 64, 64, 64, 64, 64, 1}); | |||||
AlgoParam{128, 64, 64, 64, 64, 64, 8, 8, 32, 1}); | |||||
} | } | ||||
{ | { | ||||
using AlgoParam = AlgoInt4Int4NHWCIMMAImplicitGemm::AlgoParam; | using AlgoParam = AlgoInt4Int4NHWCIMMAImplicitGemm::AlgoParam; | ||||
int4_int4_nhwc_imma.emplace_back( | int4_int4_nhwc_imma.emplace_back( | ||||
AlgoParam{128, 32, 64, 64, 32, 64, 1, 32}); | |||||
AlgoParam{128, 32, 64, 64, 32, 64, 8, 8, 32, 1, 32}); | |||||
int4_int4_nhwc_imma.emplace_back( | int4_int4_nhwc_imma.emplace_back( | ||||
AlgoParam{128, 32, 64, 64, 32, 64, 1, 16}); | |||||
AlgoParam{128, 32, 64, 64, 32, 64, 8, 8, 32, 1, 16}); | |||||
int4_int4_nhwc_imma.emplace_back( | int4_int4_nhwc_imma.emplace_back( | ||||
AlgoParam{128, 32, 64, 64, 32, 64, 1, 8}); | |||||
AlgoParam{128, 32, 64, 64, 32, 64, 8, 8, 32, 1, 8}); | |||||
int4_int4_nhwc_imma.emplace_back( | int4_int4_nhwc_imma.emplace_back( | ||||
AlgoParam{128, 64, 64, 64, 64, 64, 1, 32}); | |||||
AlgoParam{128, 64, 64, 64, 64, 64, 8, 8, 32, 1, 32}); | |||||
int4_int4_nhwc_imma.emplace_back( | int4_int4_nhwc_imma.emplace_back( | ||||
AlgoParam{128, 64, 64, 64, 64, 64, 1, 16}); | |||||
AlgoParam{128, 64, 64, 64, 64, 64, 8, 8, 32, 1, 16}); | |||||
int4_int4_nhwc_imma.emplace_back( | int4_int4_nhwc_imma.emplace_back( | ||||
AlgoParam{128, 64, 64, 64, 64, 64, 1, 8}); | |||||
AlgoParam{128, 64, 64, 64, 64, 64, 8, 8, 32, 1, 8}); | |||||
} | } | ||||
{ | { | ||||
using AlgoParam = AlgoUInt4Int4NHWCIMMAImplicitGemm::AlgoParam; | using AlgoParam = AlgoUInt4Int4NHWCIMMAImplicitGemm::AlgoParam; | ||||
uint4_int4_nhwc_imma.emplace_back( | uint4_int4_nhwc_imma.emplace_back( | ||||
AlgoParam{128, 32, 64, 64, 32, 64, 1, 32}); | |||||
AlgoParam{128, 32, 64, 64, 32, 64, 8, 8, 32, 1, 32}); | |||||
uint4_int4_nhwc_imma.emplace_back( | uint4_int4_nhwc_imma.emplace_back( | ||||
AlgoParam{128, 32, 64, 64, 32, 64, 1, 16}); | |||||
AlgoParam{128, 32, 64, 64, 32, 64, 8, 8, 32, 1, 16}); | |||||
uint4_int4_nhwc_imma.emplace_back( | uint4_int4_nhwc_imma.emplace_back( | ||||
AlgoParam{128, 32, 64, 64, 32, 64, 1, 8}); | |||||
AlgoParam{128, 32, 64, 64, 32, 64, 8, 8, 32, 1, 8}); | |||||
uint4_int4_nhwc_imma.emplace_back( | uint4_int4_nhwc_imma.emplace_back( | ||||
AlgoParam{128, 64, 64, 64, 64, 64, 1, 32}); | |||||
AlgoParam{128, 64, 64, 64, 64, 64, 8, 8, 32, 1, 32}); | |||||
uint4_int4_nhwc_imma.emplace_back( | uint4_int4_nhwc_imma.emplace_back( | ||||
AlgoParam{128, 64, 64, 64, 64, 64, 1, 16}); | |||||
AlgoParam{128, 64, 64, 64, 64, 64, 8, 8, 32, 1, 16}); | |||||
uint4_int4_nhwc_imma.emplace_back( | uint4_int4_nhwc_imma.emplace_back( | ||||
AlgoParam{128, 64, 64, 64, 64, 64, 1, 8}); | |||||
AlgoParam{128, 64, 64, 64, 64, 64, 8, 8, 32, 1, 8}); | |||||
} | } | ||||
#endif | #endif | ||||
} | } | ||||
@@ -286,15 +295,24 @@ void ConvBiasForwardImpl::AlgoPack::fill_imma_algos() { | |||||
void ConvBiasForwardImpl::AlgoPack::fill_dp4a_algos() { | void ConvBiasForwardImpl::AlgoPack::fill_dp4a_algos() { | ||||
using AlgoParam = AlgoInt8NCHW4DotProdImplicitGemm::AlgoParam; | using AlgoParam = AlgoInt8NCHW4DotProdImplicitGemm::AlgoParam; | ||||
int8_nchw4_dotprod.emplace_back(AlgoParam{128, 128, 32, 64, 32, 32, 2}); | |||||
int8_nchw4_dotprod.emplace_back(AlgoParam{128, 64, 32, 64, 32, 32, 2}); | |||||
int8_nchw4_dotprod.emplace_back(AlgoParam{64, 128, 32, 64, 32, 32, 2}); | |||||
int8_nchw4_dotprod.emplace_back(AlgoParam{32, 128, 32, 32, 64, 32, 2}); | |||||
int8_nchw4_dotprod.emplace_back(AlgoParam{128, 32, 32, 64, 32, 32, 2}); | |||||
int8_nchw4_dotprod.emplace_back(AlgoParam{32, 64, 32, 32, 64, 32, 2}); | |||||
int8_nchw4_dotprod.emplace_back(AlgoParam{64, 32, 32, 64, 32, 32, 2}); | |||||
int8_nchw4_dotprod.emplace_back(AlgoParam{16, 128, 16, 16, 128, 16, 1}); | |||||
int8_nchw4_dotprod.emplace_back(AlgoParam{16, 64, 8, 16, 64, 8, 2}); | |||||
int8_nchw4_dotprod.emplace_back( | |||||
AlgoParam{128, 128, 32, 64, 32, 32, 1, 1, 4, 2}); | |||||
int8_nchw4_dotprod.emplace_back( | |||||
AlgoParam{128, 64, 32, 64, 32, 32, 1, 1, 4, 2}); | |||||
int8_nchw4_dotprod.emplace_back( | |||||
AlgoParam{64, 128, 32, 64, 32, 32, 1, 1, 4, 2}); | |||||
int8_nchw4_dotprod.emplace_back( | |||||
AlgoParam{32, 128, 32, 32, 64, 32, 1, 1, 4, 2}); | |||||
int8_nchw4_dotprod.emplace_back( | |||||
AlgoParam{128, 32, 32, 64, 32, 32, 1, 1, 4, 2}); | |||||
int8_nchw4_dotprod.emplace_back( | |||||
AlgoParam{32, 64, 32, 32, 64, 32, 1, 1, 4, 2}); | |||||
int8_nchw4_dotprod.emplace_back( | |||||
AlgoParam{64, 32, 32, 64, 32, 32, 1, 1, 4, 2}); | |||||
int8_nchw4_dotprod.emplace_back( | |||||
AlgoParam{16, 128, 16, 16, 128, 16, 1, 1, 4, 1}); | |||||
int8_nchw4_dotprod.emplace_back( | |||||
AlgoParam{16, 64, 8, 16, 64, 8, 1, 1, 4, 2}); | |||||
} | } | ||||
ConvBiasForwardImpl::AlgoBase* | ConvBiasForwardImpl::AlgoBase* | ||||
@@ -28,6 +28,17 @@ | |||||
#include <memory> | #include <memory> | ||||
#include <unordered_map> | #include <unordered_map> | ||||
namespace cutlass { | |||||
namespace library { | |||||
// forward declaration of cutlass library concepts, we hope that algo.h does | |||||
// not depend on cutlass headers | |||||
class Operation; | |||||
} // namespace library | |||||
} // namespace cutlass | |||||
namespace megdnn { | namespace megdnn { | ||||
namespace cuda { | namespace cuda { | ||||
@@ -505,9 +516,44 @@ public: | |||||
MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_CHWN4_DOTPROD_INT8) | MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_CHWN4_DOTPROD_INT8) | ||||
}; | }; | ||||
class ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm final | |||||
: public AlgoBase { | |||||
/*********************** Cutlass Algorithms ************************/ | |||||
/* The inheritance of cutlass algorithm classes: | |||||
* | |||||
* AlgoCutlassConvolutionBase | |||||
* + | |||||
* +--- AlgoInt8NCHW4DotProdImplicitGemm | |||||
* +--- AlgoInt8NCHW32IMMAImplicitGemm | |||||
* + | |||||
* +--- AlgoInt4NCHW64IMMAImplicitGemmBase | |||||
* +----+--- AlgoInt4Int4NCHW64IMMAImplicitGemm | |||||
* +----+--- AlgoUInt4Int4NCHW64IMMAImplicitGemm | |||||
* + | |||||
* +--- AlgoInt4NHWCIMMAImplicitGemmBase | |||||
* +----+--- AlgoInt4Int4NHWCIMMAImplicitGemm | |||||
* +----+--- AlgoUInt4Int4NHWCIMMAImplicitGemm | |||||
* + | |||||
*/ | |||||
/* | |||||
* The base class for all cutlass algorithm classes | |||||
*/ | |||||
class ConvBiasForwardImpl::AlgoCutlassConvolutionBase : public AlgoBase { | |||||
public: | public: | ||||
// corresponds to cutlass::conv::Operator. we hope that algo.h does not | |||||
// depend on cutlass headers | |||||
enum class ConvOperator { kFprop, kDgrad, kWgrad }; | |||||
// corresponds to cutlass::conv::ConvType. we hope that algo.h does not | |||||
// depend on cutlass headers | |||||
enum class ConvType { | |||||
kConvolution, | |||||
kBatchConvolution, | |||||
kLocal, | |||||
kLocalShare | |||||
}; | |||||
// common parameters for operation selection | |||||
struct AlgoParam { | struct AlgoParam { | ||||
int threadblock_m; | int threadblock_m; | ||||
int threadblock_n; | int threadblock_n; | ||||
@@ -515,21 +561,54 @@ public: | |||||
int warp_m; | int warp_m; | ||||
int warp_n; | int warp_n; | ||||
int warp_k; | int warp_k; | ||||
int instruction_m; | |||||
int instruction_n; | |||||
int instruction_k; | |||||
int stage; | int stage; | ||||
std::string to_string() { | |||||
/// default algorithm | |||||
if (threadblock_m == 128 && threadblock_n == 128 && | |||||
threadblock_k == 32 && warp_m == 32 && warp_n == 64 && | |||||
warp_k == 32 && stage == 2) { | |||||
return ""; | |||||
} | |||||
return ssprintf("_%dX%dX%d_%dX%dX%d_%dstage", threadblock_m, | |||||
threadblock_n, threadblock_k, warp_m, warp_n, | |||||
warp_k, stage); | |||||
} | |||||
int access_size; | |||||
AlgoParam(int threadblock_m_, int threadblock_n_, int threadblock_k_, | |||||
int warp_m_, int warp_n_, int warp_k_, int instruction_m_, | |||||
int instruction_n_, int instruction_k_, int stage_, | |||||
int access_size_ = 0); | |||||
std::string to_string() const; | |||||
}; | }; | ||||
AlgoCutlassConvolutionBase(AlgoParam algo_param) | |||||
: m_algo_param{algo_param} {} | |||||
// generate a cutlass::library::ConvolutionKey and find the corresponding | |||||
// operation (cutlass kernel) from the global OperationTable | |||||
const cutlass::library::Operation* get_cutlass_conv_op( | |||||
const SizeArgs& args, ConvOperator conv_op, ConvType conv_type, | |||||
bool load_from_const, bool without_shared_load) const; | |||||
// execute the cutlass kernel found by get_cutlass_conv_op. we give | |||||
// subclasses full freedom to decide where and how these arguments are | |||||
// extracted | |||||
void execute_cutlass_conv_op(const cutlass::library::Operation* op, | |||||
const void* src, const void* filter, | |||||
const void* bias, const void* z, void* dst, | |||||
void* workspace, size_t n, size_t hi, | |||||
size_t wi, size_t ci, size_t co, size_t fh, | |||||
size_t fw, size_t ho, size_t wo, size_t ph, | |||||
size_t pw, size_t sh, size_t sw, size_t dh, | |||||
size_t dw, const void* alpha, const void* beta, | |||||
const void* gamma, const void* delta, | |||||
const void* theta, const void* threshold, | |||||
const void* dst_scale, cudaStream_t stream, | |||||
const void* extra_param = nullptr) const; | |||||
protected: | |||||
AlgoParam m_algo_param; | |||||
}; | |||||
class ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm final | |||||
: public AlgoCutlassConvolutionBase { | |||||
public: | |||||
AlgoInt8NCHW4DotProdImplicitGemm(AlgoParam algo_param) | AlgoInt8NCHW4DotProdImplicitGemm(AlgoParam algo_param) | ||||
: m_algo_param{algo_param}, | |||||
: AlgoCutlassConvolutionBase(algo_param), | |||||
m_name{ssprintf("INT8_NCHW4_DOTPROD_IMPLICIT_GEMM%s", | m_name{ssprintf("INT8_NCHW4_DOTPROD_IMPLICIT_GEMM%s", | ||||
m_algo_param.to_string().c_str())} {} | m_algo_param.to_string().c_str())} {} | ||||
bool is_available(const SizeArgs& args) const override; | bool is_available(const SizeArgs& args) const override; | ||||
@@ -555,7 +634,6 @@ public: | |||||
private: | private: | ||||
WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, | WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, | ||||
const SizeArgs& args) const; | const SizeArgs& args) const; | ||||
AlgoParam m_algo_param; | |||||
std::string m_name; | std::string m_name; | ||||
}; | }; | ||||
@@ -714,19 +792,10 @@ private: | |||||
#if CUDA_VERSION >= 10020 | #if CUDA_VERSION >= 10020 | ||||
class ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm final | class ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm final | ||||
: public AlgoBase { | |||||
: public AlgoCutlassConvolutionBase { | |||||
public: | public: | ||||
struct AlgoParam { | |||||
int threadblock_m; | |||||
int threadblock_n; | |||||
int threadblock_k; | |||||
int warp_m; | |||||
int warp_n; | |||||
int warp_k; | |||||
int stage; | |||||
}; | |||||
AlgoInt8NCHW32IMMAImplicitGemm(AlgoParam algo_param) | AlgoInt8NCHW32IMMAImplicitGemm(AlgoParam algo_param) | ||||
: m_algo_param{algo_param} { | |||||
: AlgoCutlassConvolutionBase(algo_param) { | |||||
m_name = ConvBias::algo_name<ConvBias::DirectParam>( | m_name = ConvBias::algo_name<ConvBias::DirectParam>( | ||||
ssprintf("INT8_NCHW32_IMMA_IMPLICIT_GEMM_%s", | ssprintf("INT8_NCHW32_IMMA_IMPLICIT_GEMM_%s", | ||||
to_string(m_algo_param).c_str()), | to_string(m_algo_param).c_str()), | ||||
@@ -757,25 +826,14 @@ private: | |||||
WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, | WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, | ||||
const SizeArgs& args) const; | const SizeArgs& args) const; | ||||
AlgoParam m_algo_param; | |||||
std::string m_name; | std::string m_name; | ||||
}; | }; | ||||
class ConvBiasForwardImpl::AlgoInt4NCHW64IMMAImplicitGemmBase | class ConvBiasForwardImpl::AlgoInt4NCHW64IMMAImplicitGemmBase | ||||
: public AlgoBase { | |||||
: public AlgoCutlassConvolutionBase { | |||||
public: | public: | ||||
struct AlgoParam { | |||||
int threadblock_m; | |||||
int threadblock_n; | |||||
int threadblock_k; | |||||
int warp_m; | |||||
int warp_n; | |||||
int warp_k; | |||||
int stage; | |||||
}; | |||||
AlgoInt4NCHW64IMMAImplicitGemmBase(AlgoParam algo_param) | AlgoInt4NCHW64IMMAImplicitGemmBase(AlgoParam algo_param) | ||||
: m_algo_param(algo_param) {} | |||||
: AlgoCutlassConvolutionBase(algo_param) {} | |||||
AlgoAttribute attribute() const override { | AlgoAttribute attribute() const override { | ||||
return AlgoAttribute::REPRODUCIBLE; | return AlgoAttribute::REPRODUCIBLE; | ||||
@@ -799,16 +857,9 @@ protected: | |||||
virtual std::tuple<float, float, float, float, float> get_constants( | virtual std::tuple<float, float, float, float, float> get_constants( | ||||
const ExecArgs& args) const = 0; | const ExecArgs& args) const = 0; | ||||
virtual void do_exec(const ExecArgs& args, void* filter_ptr, void* bias_ptr, | |||||
void* z_ptr, convolution::ConvParam kern_param, | |||||
uint32_t nonlinear_mode, float alpha, float beta, | |||||
float gamma, float delta, float theta, | |||||
cudaStream_t stream) const = 0; | |||||
void reorder_filter(const ExecArgs& args, void* reordered_filter) const; | void reorder_filter(const ExecArgs& args, void* reordered_filter) const; | ||||
std::string m_name; | std::string m_name; | ||||
AlgoParam m_algo_param; | |||||
}; | }; | ||||
class ConvBiasForwardImpl::AlgoInt4Int4NCHW64IMMAImplicitGemm final | class ConvBiasForwardImpl::AlgoInt4Int4NCHW64IMMAImplicitGemm final | ||||
@@ -842,11 +893,6 @@ private: | |||||
std::tuple<float, float, float, float, float> get_constants( | std::tuple<float, float, float, float, float> get_constants( | ||||
const ExecArgs& args) const override; | const ExecArgs& args) const override; | ||||
void do_exec(const ExecArgs& args, void* filter_ptr, void* bias_ptr, | |||||
void* z_ptr, convolution::ConvParam kern_param, | |||||
uint32_t nonlinear_mode, float alpha, float beta, float gamma, | |||||
float delta, float theta, cudaStream_t stream) const override; | |||||
}; | }; | ||||
class ConvBiasForwardImpl::AlgoUInt4Int4NCHW64IMMAImplicitGemm final | class ConvBiasForwardImpl::AlgoUInt4Int4NCHW64IMMAImplicitGemm final | ||||
@@ -881,30 +927,15 @@ private: | |||||
std::tuple<float, float, float, float, float> get_constants( | std::tuple<float, float, float, float, float> get_constants( | ||||
const ExecArgs& args) const override; | const ExecArgs& args) const override; | ||||
void do_exec(const ExecArgs& args, void* filter_ptr, void* bias_ptr, | |||||
void* z_ptr, convolution::ConvParam kern_param, | |||||
uint32_t nonlinear_mode, float alpha, float beta, float gamma, | |||||
float delta, float theta, cudaStream_t stream) const override; | |||||
void update_bias(const ExecArgs& args, void* updated_bias, | void update_bias(const ExecArgs& args, void* updated_bias, | ||||
void* reduce_filter_ptr, void* reduce_workspace) const; | void* reduce_filter_ptr, void* reduce_workspace) const; | ||||
}; | }; | ||||
class ConvBiasForwardImpl::AlgoInt4NHWCIMMAImplicitGemmBase : public AlgoBase { | |||||
class ConvBiasForwardImpl::AlgoInt4NHWCIMMAImplicitGemmBase | |||||
: public AlgoCutlassConvolutionBase { | |||||
public: | public: | ||||
struct AlgoParam { | |||||
int threadblock_m; | |||||
int threadblock_n; | |||||
int threadblock_k; | |||||
int warp_m; | |||||
int warp_n; | |||||
int warp_k; | |||||
int stage; | |||||
int access_size; | |||||
}; | |||||
AlgoInt4NHWCIMMAImplicitGemmBase(AlgoParam algo_param) | AlgoInt4NHWCIMMAImplicitGemmBase(AlgoParam algo_param) | ||||
: m_algo_param(algo_param) {} | |||||
: AlgoCutlassConvolutionBase(algo_param) {} | |||||
AlgoAttribute attribute() const override { | AlgoAttribute attribute() const override { | ||||
return AlgoAttribute::REPRODUCIBLE; | return AlgoAttribute::REPRODUCIBLE; | ||||
@@ -928,17 +959,10 @@ protected: | |||||
virtual std::tuple<float, float, float, float, float> get_constants( | virtual std::tuple<float, float, float, float, float> get_constants( | ||||
const ExecArgs& args) const = 0; | const ExecArgs& args) const = 0; | ||||
virtual void do_exec(const ExecArgs& args, void* filter_ptr, void* bias_ptr, | |||||
void* z_ptr, convolution::ConvParam kern_param, | |||||
uint32_t nonlinear_mode, float alpha, float beta, | |||||
float gamma, float delta, float theta, | |||||
cudaStream_t stream) const = 0; | |||||
void reorder_filter(const ExecArgs& args, int interleaved, | void reorder_filter(const ExecArgs& args, int interleaved, | ||||
void* reordered_filter) const; | void* reordered_filter) const; | ||||
std::string m_name; | std::string m_name; | ||||
AlgoParam m_algo_param; | |||||
}; | }; | ||||
class ConvBiasForwardImpl::AlgoInt4Int4NHWCIMMAImplicitGemm final | class ConvBiasForwardImpl::AlgoInt4Int4NHWCIMMAImplicitGemm final | ||||
@@ -971,11 +995,6 @@ private: | |||||
std::tuple<float, float, float, float, float> get_constants( | std::tuple<float, float, float, float, float> get_constants( | ||||
const ExecArgs& args) const override; | const ExecArgs& args) const override; | ||||
void do_exec(const ExecArgs& args, void* filter_ptr, void* bias_ptr, | |||||
void* z_ptr, convolution::ConvParam kern_param, | |||||
uint32_t nonlinear_mode, float alpha, float beta, float gamma, | |||||
float delta, float theta, cudaStream_t stream) const override; | |||||
}; | }; | ||||
class ConvBiasForwardImpl::AlgoUInt4Int4NHWCIMMAImplicitGemm final | class ConvBiasForwardImpl::AlgoUInt4Int4NHWCIMMAImplicitGemm final | ||||
@@ -1009,11 +1028,6 @@ private: | |||||
std::tuple<float, float, float, float, float> get_constants( | std::tuple<float, float, float, float, float> get_constants( | ||||
const ExecArgs& args) const override; | const ExecArgs& args) const override; | ||||
void do_exec(const ExecArgs& args, void* filter_ptr, void* bias_ptr, | |||||
void* z_ptr, convolution::ConvParam kern_param, | |||||
uint32_t nonlinear_mode, float alpha, float beta, float gamma, | |||||
float delta, float theta, cudaStream_t stream) const override; | |||||
void update_bias(const ExecArgs& args, void* updated_bias, | void update_bias(const ExecArgs& args, void* updated_bias, | ||||
void* reduce_filter_ptr, void* reduce_workspace) const; | void* reduce_filter_ptr, void* reduce_workspace) const; | ||||
}; | }; | ||||
@@ -0,0 +1,253 @@ | |||||
/** | |||||
* \file dnn/src/cuda/conv_bias/cutlass_convolution_base.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#include "src/cuda/conv_bias/algo.h" | |||||
#include "src/cuda/cutlass/singleton.h" | |||||
namespace megdnn { | |||||
namespace cuda { | |||||
using namespace cutlass::library; | |||||
using namespace cutlass::epilogue; | |||||
ConvBiasForwardImpl::AlgoCutlassConvolutionBase::AlgoParam::AlgoParam( | |||||
int threadblock_m_, int threadblock_n_, int threadblock_k_, int warp_m_, | |||||
int warp_n_, int warp_k_, int instruction_m_, int instruction_n_, | |||||
int instruction_k_, int stage_, int access_size_) | |||||
: threadblock_m(threadblock_m_), | |||||
threadblock_n(threadblock_n_), | |||||
threadblock_k(threadblock_k_), | |||||
warp_m(warp_m_), | |||||
warp_n(warp_n_), | |||||
warp_k(warp_k_), | |||||
instruction_m(instruction_m_), | |||||
instruction_n(instruction_m_), | |||||
instruction_k(instruction_k_), | |||||
stage(stage_), | |||||
access_size(access_size_) {} | |||||
std::string | |||||
ConvBiasForwardImpl::AlgoCutlassConvolutionBase::AlgoParam::to_string() const { | |||||
/// default algorithm | |||||
if (threadblock_m == 128 && threadblock_n == 128 && threadblock_k == 32 && | |||||
warp_m == 32 && warp_n == 64 && warp_k == 32 && stage == 2) { | |||||
return ""; | |||||
} | |||||
return ssprintf("_%dX%dX%d_%dX%dX%d_%dstage", threadblock_m, threadblock_n, | |||||
threadblock_k, warp_m, warp_n, warp_k, stage); | |||||
} | |||||
namespace { | |||||
using Base = ConvBiasForwardImpl::AlgoCutlassConvolutionBase; | |||||
cutlass::conv::Operator convert_conv_op(Base::ConvOperator conv_op) { | |||||
switch (conv_op) { | |||||
case Base::ConvOperator::kFprop: | |||||
return cutlass::conv::Operator::kFprop; | |||||
case Base::ConvOperator::kDgrad: | |||||
return cutlass::conv::Operator::kDgrad; | |||||
case Base::ConvOperator::kWgrad: | |||||
return cutlass::conv::Operator::kWgrad; | |||||
default: | |||||
megdnn_assert(0, "invalid conv op"); | |||||
} | |||||
} | |||||
cutlass::conv::ConvType convert_conv_type(Base::ConvType conv_type) { | |||||
switch (conv_type) { | |||||
case Base::ConvType::kConvolution: | |||||
return cutlass::conv::ConvType::kConvolution; | |||||
case Base::ConvType::kBatchConvolution: | |||||
return cutlass::conv::ConvType::kBatchConvolution; | |||||
case Base::ConvType::kLocal: | |||||
return cutlass::conv::ConvType::kLocal; | |||||
case Base::ConvType::kLocalShare: | |||||
return cutlass::conv::ConvType::kLocalShare; | |||||
default: | |||||
megdnn_assert(0, "invalid conv type"); | |||||
} | |||||
} | |||||
NumericTypeID convert_dtype(DTypeEnum dtype) { | |||||
switch (dtype) { | |||||
case DTypeEnum::Float32: | |||||
return NumericTypeID::kF32; | |||||
case DTypeEnum::Float16: | |||||
return NumericTypeID::kF16; | |||||
case DTypeEnum::Int8: | |||||
return NumericTypeID::kS8; | |||||
case DTypeEnum::QuantizedS32: | |||||
return NumericTypeID::kS32; | |||||
case DTypeEnum::QuantizedS8: | |||||
return NumericTypeID::kS8; | |||||
case DTypeEnum::QuantizedS4: | |||||
return NumericTypeID::kS4; | |||||
case DTypeEnum::Quantized4Asymm: | |||||
return NumericTypeID::kU4; | |||||
default: | |||||
megdnn_assert(0, "invalid dtype"); | |||||
} | |||||
} | |||||
struct LayoutPack { | |||||
LayoutTypeID src; | |||||
LayoutTypeID filter; | |||||
LayoutTypeID dst; | |||||
LayoutTypeID bias; | |||||
}; | |||||
LayoutPack get_layout_pack(const param::ConvBias::Format format, | |||||
int access_type) { | |||||
using Format = param::ConvBias::Format; | |||||
switch (format) { | |||||
case Format::NCHW4: | |||||
return {LayoutTypeID::kTensorNC4HW4, LayoutTypeID::kTensorC4RSK4, | |||||
LayoutTypeID::kTensorNC4HW4, LayoutTypeID::kTensorNC4HW4}; | |||||
case Format::NCHW4_NCHW: | |||||
return {LayoutTypeID::kTensorNC4HW4, LayoutTypeID::kTensorC4RSK4, | |||||
LayoutTypeID::kTensorNCHW, LayoutTypeID::kTensorNCHW}; | |||||
case Format::NCHW4_NHWC: | |||||
return {LayoutTypeID::kTensorNC4HW4, LayoutTypeID::kTensorC4RSK4, | |||||
LayoutTypeID::kTensorNHWC, LayoutTypeID::kTensorNHWC}; | |||||
case Format::NCHW4_NCHW32: | |||||
return {LayoutTypeID::kTensorNC4HW4, LayoutTypeID::kTensorC4RSK4, | |||||
LayoutTypeID::kTensorNC32HW32, | |||||
LayoutTypeID::kTensorNC32HW32}; | |||||
case Format::NCHW32: | |||||
return {LayoutTypeID::kTensorNC32HW32, | |||||
LayoutTypeID::kTensorC32RSK32, | |||||
LayoutTypeID::kTensorNC32HW32, | |||||
LayoutTypeID::kTensorNC32HW32}; | |||||
case Format::NCHW32_NCHW4: | |||||
return {LayoutTypeID::kTensorNC32HW32, | |||||
LayoutTypeID::kTensorC32RSK32, LayoutTypeID::kTensorNC4HW4, | |||||
LayoutTypeID::kTensorNC4HW4}; | |||||
case Format::NCHW64: | |||||
return {LayoutTypeID::kTensorNC64HW64, | |||||
LayoutTypeID::kTensorC64RSK64, | |||||
LayoutTypeID::kTensorNC64HW64, | |||||
LayoutTypeID::kTensorNC64HW64}; | |||||
case Format::NHWC: | |||||
switch (access_type) { | |||||
case 8: | |||||
return {LayoutTypeID::kTensorNHWC, | |||||
LayoutTypeID::kTensorNC8HW8, | |||||
LayoutTypeID::kTensorNHWC, | |||||
LayoutTypeID::kTensorNHWC}; | |||||
case 16: | |||||
return {LayoutTypeID::kTensorNHWC, | |||||
LayoutTypeID::kTensorNC16HW16, | |||||
LayoutTypeID::kTensorNHWC, | |||||
LayoutTypeID::kTensorNHWC}; | |||||
case 32: | |||||
return {LayoutTypeID::kTensorNHWC, | |||||
LayoutTypeID::kTensorNC32HW32, | |||||
LayoutTypeID::kTensorNHWC, | |||||
LayoutTypeID::kTensorNHWC}; | |||||
default: | |||||
megdnn_assert(0, "invalid access_type"); | |||||
} | |||||
default: | |||||
megdnn_assert(0, "invalid format"); | |||||
} | |||||
} | |||||
EpilogueType get_epilogue_type(const param::ConvBias::NonlineMode mode, | |||||
bool clamp) { | |||||
using NonlineMode = param::ConvBias::NonlineMode; | |||||
if (clamp) { | |||||
if (mode == NonlineMode::IDENTITY) { | |||||
return EpilogueType::kBiasAddLinearCombinationClamp; | |||||
} else if (mode == NonlineMode::RELU) { | |||||
return EpilogueType::kBiasAddLinearCombinationReluClamp; | |||||
} else if (mode == NonlineMode::H_SWISH) { | |||||
return EpilogueType::kBiasAddLinearCombinationHSwishClamp; | |||||
} | |||||
} else { | |||||
if (mode == NonlineMode::IDENTITY) { | |||||
return EpilogueType::kBiasAddLinearCombination; | |||||
} else if (mode == NonlineMode::RELU) { | |||||
return EpilogueType::kBiasAddLinearCombinationRelu; | |||||
} else if (mode == NonlineMode::H_SWISH) { | |||||
return EpilogueType::kBiasAddLinearCombinationHSwish; | |||||
} | |||||
} | |||||
megdnn_assert(0, "invalid nonlinear mode"); | |||||
} | |||||
} // namespace | |||||
const Operation* | |||||
ConvBiasForwardImpl::AlgoCutlassConvolutionBase::get_cutlass_conv_op( | |||||
const SizeArgs& args, ConvOperator conv_op, ConvType conv_type, | |||||
bool load_from_const, bool without_shared_load) const { | |||||
using Format = param::ConvBias::Format; | |||||
auto&& param = args.opr->param(); | |||||
auto layouts = get_layout_pack(param.format, m_algo_param.access_size); | |||||
auto epilogue_type = get_epilogue_type(param.nonlineMode, | |||||
param.format != Format::NCHW4_NCHW); | |||||
ConvolutionKey key{convert_conv_op(conv_op), | |||||
convert_dtype(args.src_layout->dtype.enumv()), | |||||
layouts.src, | |||||
convert_dtype(args.filter_layout->dtype.enumv()), | |||||
layouts.filter, | |||||
convert_dtype(args.dst_layout->dtype.enumv()), | |||||
layouts.dst, | |||||
convert_dtype(args.bias_layout->dtype.enumv()), | |||||
layouts.bias, | |||||
convert_conv_type(conv_type), | |||||
m_algo_param.threadblock_m, | |||||
m_algo_param.threadblock_n, | |||||
m_algo_param.threadblock_k, | |||||
m_algo_param.warp_m, | |||||
m_algo_param.warp_n, | |||||
m_algo_param.warp_k, | |||||
m_algo_param.instruction_m, | |||||
m_algo_param.instruction_n, | |||||
m_algo_param.instruction_k, | |||||
epilogue_type, | |||||
m_algo_param.stage, | |||||
load_from_const, | |||||
without_shared_load}; | |||||
return Singleton::get().operation_table.find_op(key); | |||||
} | |||||
void ConvBiasForwardImpl::AlgoCutlassConvolutionBase::execute_cutlass_conv_op( | |||||
const Operation* op, const void* src, const void* filter, | |||||
const void* bias, const void* z, void* dst, void* workspace, size_t n, | |||||
size_t hi, size_t wi, size_t ci, size_t co, size_t fh, size_t fw, | |||||
size_t ho, size_t wo, size_t ph, size_t pw, size_t sh, size_t sw, | |||||
size_t dh, size_t dw, const void* alpha, const void* beta, | |||||
const void* gamma, const void* delta, const void* theta, | |||||
const void* threshold, const void* dst_scale, cudaStream_t stream, | |||||
const void* extra_param) const { | |||||
// gcc prints warnings when size_t values are implicitly narrowed to int | |||||
cutlass::conv::Conv2dProblemSize problem_size{ | |||||
int(n), int(hi), int(wi), int(ci), | |||||
int(co), int(fh), int(fw), int(ho), | |||||
int(wo), int(ph), int(pw), int(sh), | |||||
int(sw), int(dh), int(dw), cutlass::conv::Mode::kCrossCorrelation}; | |||||
ConvolutionArguments conv_args{ | |||||
problem_size, src, filter, bias, z, | |||||
dst, alpha, beta, gamma, delta, | |||||
theta, threshold, dst_scale, extra_param}; | |||||
cutlass_check(op->run(&conv_args, workspace, stream)); | |||||
} | |||||
} // namespace cuda | |||||
} // namespace megdnn |
@@ -1,129 +0,0 @@ | |||||
/** | |||||
* \file dnn/src/cuda/conv_bias/cutlass_convolution_wrapper.cuh | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#pragma once | |||||
#include "cutlass/gemm/gemm.h" | |||||
#include "src/cuda/convolution_helper/parameter.cuh" | |||||
#include "src/cuda/utils.cuh" | |||||
namespace megdnn { | |||||
namespace cuda { | |||||
namespace cutlass_wrapper { | |||||
using GemmCoord = cutlass::gemm::GemmCoord; | |||||
template <typename Convolution> | |||||
void cutlass_convolution_wrapper( | |||||
const typename Convolution::ElementSrc* d_src, | |||||
const typename Convolution::ElementFilter* d_filter, | |||||
const typename Convolution::ElementBias* d_bias, | |||||
const typename Convolution::ElementDst* d_z, | |||||
typename Convolution::ElementDst* d_dst, int* workspace, | |||||
typename Convolution::ConvolutionParameter const& conv_param, | |||||
typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||||
cudaStream_t stream, typename Convolution::ExtraParam extra_param = {}); | |||||
template <bool NeedLoadFromConstMem> | |||||
void do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32( | |||||
const int8_t* d_src, const int8_t* d_filter, const int32_t* d_bias, | |||||
const int8_t* d_z, int8_t* d_dst, int* workspace, | |||||
const convolution::ConvParam& param, uint32_t nonlinear_mode, | |||||
float alpha, float beta, float gamma, float scale, | |||||
const GemmCoord& threadblock_shape, const GemmCoord& warp_shape, | |||||
int stages, cudaStream_t stream); | |||||
template <bool NeedLoadFromConstMem> | |||||
void do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4( | |||||
const int8_t* d_src, const int8_t* d_filter, const int32_t* d_bias, | |||||
const int8_t* d_z, int8_t* d_dst, int* workspace, | |||||
const convolution::ConvParam& param, uint32_t nonlinear_mode, | |||||
float alpha, float beta, float gamma, float scale, | |||||
const GemmCoord& threadblock_shape, const GemmCoord& warp_shape, | |||||
int stages, cudaStream_t stream); | |||||
template <bool NeedLoadFromConstMem> | |||||
void do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4( | |||||
const int8_t* d_src, const int8_t* d_filter, const int32_t* d_bias, | |||||
const int8_t* d_z, int8_t* d_dst, int* workspace, | |||||
const convolution::ConvParam& param, uint32_t nonlinear_mode, | |||||
float alpha, float beta, float gamma, float scale, | |||||
const GemmCoord& threadblock_shape, const GemmCoord& warp_shape, | |||||
int stages, cudaStream_t stream); | |||||
template <bool NeedLoadFromConstMem> | |||||
void do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nchw( | |||||
const int8_t* d_src, const int8_t* d_filter, const float* d_bias, | |||||
const float* d_z, float* d_dst, int* workspace, | |||||
const convolution::ConvParam& param, uint32_t nonlinear_mode, | |||||
float alpha, float beta, float gamma, float scale, | |||||
const GemmCoord& threadblock_shape, const GemmCoord& warp_shape, | |||||
int stages, cudaStream_t stream); | |||||
template <bool NeedLoadFromConstMem> | |||||
void do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32( | |||||
const int8_t* d_src, const int8_t* d_filter, const int32_t* d_bias, | |||||
const int8_t* d_z, int8_t* d_dst, int* workspace, | |||||
const convolution::ConvParam& param, uint32_t nonlinear_mode, | |||||
float alpha, float beta, float gamma, float scale, | |||||
const GemmCoord& threadblock_shape, const GemmCoord& warp_shape, | |||||
int stages, cudaStream_t stream); | |||||
template <bool NeedLoadFromConstMem> | |||||
void do_conv_bias_int4_int4_implicit_gemm_imma_ncdiv64hw64( | |||||
const int8_t* d_src, const int8_t* d_filter, const int32_t* d_bias, | |||||
const int8_t* d_z, int8_t* d_dst, int* workspace, | |||||
const convolution::ConvParam& param, uint32_t nonlinear_mode, | |||||
float alpha, float beta, float gamma, float scale, | |||||
const GemmCoord& threadblock_shape, const GemmCoord& warp_shape, | |||||
int stages, cudaStream_t stream); | |||||
template <bool NeedLoadFromConstMem> | |||||
void do_conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64( | |||||
const uint8_t* d_src, const int8_t* d_filter, const int32_t* d_bias, | |||||
const uint8_t* d_z, uint8_t* d_dst, int* workspace, | |||||
const convolution::ConvParam& param, uint32_t nonlinear_mode, | |||||
float alpha, float beta, float gamma, float delta, float theta, | |||||
float scale, uint8_t src_zero_point, const GemmCoord& threadblock_shape, | |||||
const GemmCoord& warp_shape, int stages, cudaStream_t stream); | |||||
template <bool signedness> | |||||
void do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nhwc( | |||||
const int8_t* d_src, const int8_t* d_filter, const int32_t* d_bias, | |||||
const int8_t* d_z, int8_t* d_dst, int* workspace, | |||||
const convolution::ConvParam& param, uint32_t nonlinear_mode, | |||||
float alpha, float beta, float gamma, float delta, float theta, | |||||
float scale, const GemmCoord& threadblock_shape, | |||||
const GemmCoord& warp_shape, int stages, cudaStream_t stream); | |||||
template <bool NeedLoadFromConstMem> | |||||
void do_conv_bias_int4_int4_implicit_gemm_imma_nhwc( | |||||
const int8_t* d_src, const int8_t* d_filter, const int32_t* d_bias, | |||||
const int8_t* d_z, int8_t* d_dst, int* workspace, | |||||
const convolution::ConvParam& param, uint32_t nonlinear_mode, | |||||
float alpha, float beta, float gamma, float scale, | |||||
const GemmCoord& threadblock_shape, const GemmCoord& warp_shape, | |||||
const int32_t access_size, int stages, cudaStream_t stream); | |||||
template <bool NeedLoadFromConstMem> | |||||
void do_conv_bias_uint4_int4_implicit_gemm_imma_nhwc( | |||||
const uint8_t* d_src, const int8_t* d_filter, const int32_t* d_bias, | |||||
const uint8_t* d_z, uint8_t* d_dst, int* workspace, | |||||
const convolution::ConvParam& param, uint32_t nonlinear_mode, | |||||
float alpha, float beta, float gamma, float delta, float theta, | |||||
float scale, uint8_t src_zero_point, const GemmCoord& threadblock_shape, | |||||
const GemmCoord& warp_shape, const int32_t access_size, int stages, | |||||
cudaStream_t stream); | |||||
} // namespace cutlass_wrapper | |||||
} // namespace cuda | |||||
} // namespace megdnn | |||||
// vim: syntax=cuda.doxygen |
@@ -1,595 +0,0 @@ | |||||
/** | |||||
* \file dnn/src/cuda/conv_bias/cutlass_convolution_wrapper.cu | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
// ignore warning of cutlass | |||||
#pragma GCC diagnostic push | |||||
#pragma GCC diagnostic ignored "-Wunused-parameter" | |||||
#pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||||
#if !MEGDNN_TEGRA_X1 | |||||
#include "cutlass/convolution/device/convolution.h" | |||||
#endif | |||||
#include "src/common/opr_param_defs_enumv.cuh" | |||||
#include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh" | |||||
#pragma GCC diagnostic pop | |||||
using namespace megdnn; | |||||
using namespace cuda; | |||||
using namespace cutlass_wrapper; | |||||
/* ====== cutlass kernel wrapper for int4 x int4 nchw64 layout ====== */ | |||||
#if MEGDNN_TEGRA_X1 | |||||
template <bool NeedLoadFromConstMem> | |||||
void megdnn::cuda::cutlass_wrapper:: | |||||
do_conv_bias_int4_int4_implicit_gemm_imma_ncdiv64hw64( | |||||
const int8_t* /* d_src */, const int8_t* /* d_filter */, | |||||
const int32_t* /* d_bias */, const int8_t* /* d_z */, | |||||
int8_t* /* d_dst */, int* /* workspace */, | |||||
const convolution::ConvParam& /* param */, | |||||
uint32_t /* nonlinear_mode */, float /* alpha */, | |||||
float /* beta */, float /* gamma */, float /* scale */, | |||||
const GemmCoord& /* threadblock_shape */, | |||||
const GemmCoord& /* warp_shape */, int /* stages */, | |||||
cudaStream_t /* stream */) {} | |||||
#else | |||||
template <bool NeedLoadFromConstMem> | |||||
void megdnn::cuda::cutlass_wrapper:: | |||||
do_conv_bias_int4_int4_implicit_gemm_imma_ncdiv64hw64( | |||||
const int8_t* d_src, const int8_t* d_filter, | |||||
const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, | |||||
int* workspace, const convolution::ConvParam& param, | |||||
uint32_t nonlinear_mode, float alpha, float beta, float gamma, | |||||
float scale, const GemmCoord& threadblock_shape, | |||||
const GemmCoord& warp_shape, int stages, cudaStream_t stream) { | |||||
#define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \ | |||||
threadblock_k_, warp_m_, warp_n_, \ | |||||
warp_k_, stage_) \ | |||||
if (threadblock_shape.m() == threadblock_m_ && \ | |||||
threadblock_shape.n() == threadblock_n_ && \ | |||||
threadblock_shape.k() == threadblock_k_ && \ | |||||
warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ | |||||
warp_shape.k() == warp_k_ && stages == stage_) { \ | |||||
using ThreadBlockShape = \ | |||||
cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \ | |||||
threadblock_k_>; \ | |||||
using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \ | |||||
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; \ | |||||
using Convolution = cutlass::conv::device::Convolution< \ | |||||
cutlass::int4b_t, cutlass::layout::TensorNCxHWx<64>, \ | |||||
cutlass::int4b_t, cutlass::layout::TensorCxRSKx<64>, \ | |||||
ElementOutput, cutlass::layout::TensorNCxHWx<64>, int32_t, \ | |||||
cutlass::layout::TensorNCxHWx<64>, int32_t, \ | |||||
cutlass::conv::ConvType::kConvolution, \ | |||||
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \ | |||||
ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ | |||||
cutlass::conv::threadblock:: \ | |||||
ConvolutionFpropTransThreadblockSwizzle, \ | |||||
stage_, 32, 32, NeedLoadFromConstMem, \ | |||||
cutlass::arch::OpMultiplyAddSaturate, \ | |||||
cutlass::conv::ImplicitGemmMode::GEMM_TN, true>; \ | |||||
typename Convolution::ConvolutionParameter conv_param( \ | |||||
param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ | |||||
param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ | |||||
param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \ | |||||
return cutlass_convolution_wrapper<Convolution>( \ | |||||
reinterpret_cast<const cutlass::int4b_t*>(d_src), \ | |||||
reinterpret_cast<const cutlass::int4b_t*>(d_filter), d_bias, \ | |||||
reinterpret_cast<const cutlass::int4b_t*>(d_z), \ | |||||
reinterpret_cast<cutlass::int4b_t*>(d_dst), workspace, \ | |||||
conv_param, epilogue, stream); \ | |||||
} | |||||
#define DISPATCH_KERNEL \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 128, 64, 64, 128, 2); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 256, 128, 64, 64, 128, 2); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 128, 64, 64, 128, 2); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1); \ | |||||
megdnn_assert(false, \ | |||||
"unsupported threadblock shape (%dx%dx%d) and warp shape " \ | |||||
"(%dx%dx%d)", \ | |||||
threadblock_shape.m(), threadblock_shape.n(), \ | |||||
threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ | |||||
warp_shape.k()); | |||||
using ElementOutput = cutlass::int4b_t; | |||||
using ElementAccumulator = int32_t; | |||||
using ElementBias = int32_t; | |||||
using ElementCompute = float; | |||||
using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode; | |||||
switch (nonlinear_mode) { | |||||
case NonlineMode::IDENTITY: { | |||||
using EpilogueOp = | |||||
cutlass::epilogue::thread::BiasAddLinearCombinationClamp< | |||||
ElementOutput, 16, ElementAccumulator, ElementBias, | |||||
ElementCompute>; | |||||
typename EpilogueOp::Params epilogue{alpha, beta, gamma}; | |||||
DISPATCH_KERNEL; | |||||
} | |||||
case NonlineMode::RELU: { | |||||
using EpilogueOp = cutlass::epilogue::thread:: | |||||
BiasAddLinearCombinationReluClamp< | |||||
ElementOutput, 16, ElementAccumulator, ElementBias, | |||||
ElementCompute>; | |||||
typename EpilogueOp::Params epilogue{alpha, beta, gamma, 0}; | |||||
DISPATCH_KERNEL; | |||||
} | |||||
case NonlineMode::H_SWISH: { | |||||
using EpilogueOp = cutlass::epilogue::thread:: | |||||
BiasAddLinearCombinationHSwishClamp< | |||||
ElementOutput, 16, ElementAccumulator, ElementBias, | |||||
ElementCompute>; | |||||
typename EpilogueOp::Params epilogue{alpha, beta, gamma, scale}; | |||||
DISPATCH_KERNEL; | |||||
} | |||||
default: | |||||
megdnn_assert(false, | |||||
"unsupported nonlinear mode for conv bias operator"); | |||||
} | |||||
#undef DISPATCH_KERNEL_WITH_TILE_SHAPE | |||||
#undef DISPATCH_KERNEL | |||||
} | |||||
#endif | |||||
#define INST(need_load_from_const_mem) \ | |||||
template void megdnn::cuda::cutlass_wrapper:: \ | |||||
do_conv_bias_int4_int4_implicit_gemm_imma_ncdiv64hw64< \ | |||||
need_load_from_const_mem>( \ | |||||
const int8_t* d_src, const int8_t* d_filter, \ | |||||
const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, \ | |||||
int* workspace, const convolution::ConvParam& param, \ | |||||
uint32_t nonlinear_mode, float alpha, float beta, \ | |||||
float gamma, float scale, \ | |||||
const GemmCoord& threadblock_shape, \ | |||||
const GemmCoord& warp_shape, int stages, \ | |||||
cudaStream_t stream); | |||||
INST(true); | |||||
#undef INST | |||||
/* ====== cutlass kernel wrapper for uint4 x int4 nchw64 layout ====== */ | |||||
#if MEGDNN_TEGRA_X1 | |||||
template <bool NeedLoadFromConstMem> | |||||
void megdnn::cuda::cutlass_wrapper:: | |||||
do_conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64( | |||||
const uint8_t* /* d_src */, const int8_t* /* d_filter */, | |||||
const int32_t* /* d_bias */, const uint8_t* /* d_z */, | |||||
uint8_t* /* d_dst */, int* /* workspace */, | |||||
const convolution::ConvParam& /* param */, | |||||
uint32_t /* nonlinear_mode */, float /* alpha */, | |||||
float /* beta */, float /* gamma */, float /* delta */, | |||||
float /* theta */, float /* scale */, | |||||
uint8_t /* src_zero_point */, | |||||
const GemmCoord& /* threadblock_shape */, | |||||
const GemmCoord& /* warp_shape */, int /* stages */, | |||||
cudaStream_t /* stream */) {} | |||||
#else | |||||
template <bool NeedLoadFromConstMem> | |||||
void megdnn::cuda::cutlass_wrapper:: | |||||
do_conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64( | |||||
const uint8_t* d_src, const int8_t* d_filter, | |||||
const int32_t* d_bias, const uint8_t* d_z, uint8_t* d_dst, | |||||
int* workspace, const convolution::ConvParam& param, | |||||
uint32_t nonlinear_mode, float alpha, float beta, float gamma, | |||||
float delta, float theta, float /* scale */, | |||||
uint8_t src_zero_point, const GemmCoord& threadblock_shape, | |||||
const GemmCoord& warp_shape, int stages, cudaStream_t stream) { | |||||
#define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \ | |||||
threadblock_k_, warp_m_, warp_n_, \ | |||||
warp_k_, stage_) \ | |||||
if (threadblock_shape.m() == threadblock_m_ && \ | |||||
threadblock_shape.n() == threadblock_n_ && \ | |||||
threadblock_shape.k() == threadblock_k_ && \ | |||||
warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ | |||||
warp_shape.k() == warp_k_ && stages == stage_) { \ | |||||
using ThreadBlockShape = \ | |||||
cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \ | |||||
threadblock_k_>; \ | |||||
using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \ | |||||
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; \ | |||||
using Convolution = cutlass::conv::device::Convolution< \ | |||||
cutlass::uint4b_t, cutlass::layout::TensorNCxHWx<64>, \ | |||||
cutlass::int4b_t, cutlass::layout::TensorCxRSKx<64>, \ | |||||
ElementOutput, cutlass::layout::TensorNCxHWx<64>, int32_t, \ | |||||
cutlass::layout::TensorNCxHWx<64>, int32_t, \ | |||||
cutlass::conv::ConvType::kConvolution, \ | |||||
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \ | |||||
ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ | |||||
cutlass::conv::threadblock:: \ | |||||
ConvolutionFpropTransThreadblockSwizzle, \ | |||||
stage_, 32, 32, NeedLoadFromConstMem, \ | |||||
cutlass::arch::OpMultiplyAddSaturate, \ | |||||
cutlass::conv::ImplicitGemmMode::GEMM_TN, true>; \ | |||||
typename Convolution::ConvolutionParameter conv_param( \ | |||||
param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ | |||||
param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ | |||||
param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \ | |||||
return cutlass_convolution_wrapper<Convolution>( \ | |||||
reinterpret_cast<const cutlass::uint4b_t*>(d_src), \ | |||||
reinterpret_cast<const cutlass::int4b_t*>(d_filter), d_bias, \ | |||||
reinterpret_cast<const cutlass::uint4b_t*>(d_z), \ | |||||
reinterpret_cast<cutlass::uint4b_t*>(d_dst), workspace, \ | |||||
conv_param, epilogue, stream, {src_zero_point}); \ | |||||
} | |||||
#define DISPATCH_KERNEL \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 128, 64, 64, 128, 2); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 256, 128, 64, 64, 128, 2); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 128, 64, 64, 128, 2); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1); \ | |||||
megdnn_assert(false, \ | |||||
"unsupported threadblock shape (%dx%dx%d) and warp shape " \ | |||||
"(%dx%dx%d)", \ | |||||
threadblock_shape.m(), threadblock_shape.n(), \ | |||||
threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ | |||||
warp_shape.k()); | |||||
using ElementOutput = cutlass::uint4b_t; | |||||
using ElementAccumulator = int32_t; | |||||
using ElementBias = int32_t; | |||||
using ElementCompute = float; | |||||
using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode; | |||||
switch (nonlinear_mode) { | |||||
case NonlineMode::IDENTITY: { | |||||
using EpilogueOp = | |||||
cutlass::epilogue::thread::BiasAddLinearCombinationClamp< | |||||
ElementOutput, 16, ElementAccumulator, ElementBias, | |||||
ElementCompute>; | |||||
typename EpilogueOp::Params epilogue{alpha, beta, gamma, | |||||
delta + theta}; | |||||
DISPATCH_KERNEL; | |||||
} | |||||
case NonlineMode::RELU: { | |||||
using EpilogueOp = cutlass::epilogue::thread:: | |||||
BiasAddLinearCombinationReluClamp< | |||||
ElementOutput, 16, ElementAccumulator, ElementBias, | |||||
ElementCompute>; | |||||
typename EpilogueOp::Params epilogue{alpha, beta, gamma, | |||||
0, delta, theta}; | |||||
DISPATCH_KERNEL; | |||||
} | |||||
default: | |||||
megdnn_assert(false, | |||||
"unsupported nonlinear mode for conv bias operator"); | |||||
} | |||||
#undef DISPATCH_KERNEL_WITH_TILE_SHAPE | |||||
#undef DISPATCH_KERNEL | |||||
} | |||||
#endif | |||||
#define INST(need_load_from_const_mem) \ | |||||
template void megdnn::cuda::cutlass_wrapper:: \ | |||||
do_conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64< \ | |||||
need_load_from_const_mem>( \ | |||||
const uint8_t* d_src, const int8_t* d_filter, \ | |||||
const int32_t* d_bias, const uint8_t* d_z, uint8_t* d_dst, \ | |||||
int* workspace, const convolution::ConvParam& param, \ | |||||
uint32_t nonlinear_mode, float alpha, float beta, \ | |||||
float gamma, float delta, float theta, float scale, \ | |||||
uint8_t src_zero_point, \ | |||||
const GemmCoord& threadblock_shape, \ | |||||
const GemmCoord& warp_shape, int stages, \ | |||||
cudaStream_t stream); | |||||
INST(true); | |||||
#undef INST | |||||
/* ====== cutlass kernel wrapper for int4 x int4 nhwc layout ====== */ | |||||
#if MEGDNN_TEGRA_X1 | |||||
template <bool NeedLoadFromConstMem> | |||||
void megdnn::cuda::cutlass_wrapper:: | |||||
do_conv_bias_int4_int4_implicit_gemm_imma_nhwc( | |||||
const int8_t* /* d_src */, const int8_t* /* d_filter */, | |||||
const int32_t* /* d_bias */, const int8_t* /* d_z */, | |||||
int8_t* /* d_dst */, int* /* workspace */, | |||||
const convolution::ConvParam& /* param */, | |||||
uint32_t /* nonlinear_mode */, float /* alpha */, | |||||
float /* beta */, float /* gamma */, float /* scale */, | |||||
const GemmCoord& /* threadblock_shape */, | |||||
const GemmCoord& /* warp_shape */, | |||||
const int32_t /* access_size */, int /* stages */, | |||||
cudaStream_t /* stream */) {} | |||||
#else | |||||
template <bool NeedLoadFromConstMem> | |||||
void megdnn::cuda::cutlass_wrapper:: | |||||
do_conv_bias_int4_int4_implicit_gemm_imma_nhwc( | |||||
const int8_t* d_src, const int8_t* d_filter, | |||||
const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, | |||||
int* workspace, const convolution::ConvParam& param, | |||||
uint32_t nonlinear_mode, float alpha, float beta, float gamma, | |||||
float scale, const GemmCoord& threadblock_shape, | |||||
const GemmCoord& warp_shape, const int32_t access_size, | |||||
int stages, cudaStream_t stream) { | |||||
bool without_shared_load = | |||||
((param.co % threadblock_shape.n() == 0) && | |||||
(threadblock_shape.n() == 32 || threadblock_shape.n() == 64)); | |||||
int out_elements_per_access = | |||||
without_shared_load ? threadblock_shape.n() / 4 : 8; | |||||
#define RUN_CUTLASS_WRAPPER(stage_, access_size_, without_shared_load_) \ | |||||
using Convolution = cutlass::conv::device::Convolution< \ | |||||
cutlass::int4b_t, cutlass::layout::TensorNHWC, cutlass::int4b_t, \ | |||||
cutlass::layout::TensorNCxHWx<access_size_>, ElementOutput, \ | |||||
cutlass::layout::TensorNHWC, int32_t, cutlass::layout::TensorNHWC, \ | |||||
int32_t, cutlass::conv::ConvType::kConvolution, \ | |||||
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \ | |||||
ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ | |||||
cutlass::conv::threadblock:: \ | |||||
ConvolutionFpropTransThreadblockSwizzle, \ | |||||
stage_, access_size_, access_size_, NeedLoadFromConstMem, \ | |||||
cutlass::arch::OpMultiplyAddSaturate, \ | |||||
cutlass::conv::ImplicitGemmMode::GEMM_TN, without_shared_load_>; \ | |||||
typename Convolution::ConvolutionParameter conv_param( \ | |||||
param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ | |||||
param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ | |||||
param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \ | |||||
return cutlass_convolution_wrapper<Convolution>( \ | |||||
reinterpret_cast<const cutlass::int4b_t*>(d_src), \ | |||||
reinterpret_cast<const cutlass::int4b_t*>(d_filter), d_bias, \ | |||||
reinterpret_cast<const cutlass::int4b_t*>(d_z), \ | |||||
reinterpret_cast<cutlass::int4b_t*>(d_dst), workspace, conv_param, \ | |||||
epilogue, stream); | |||||
#define DISPATCH_KERNEL_WITH_TILE_SHAPE( \ | |||||
threadblock_m_, threadblock_n_, threadblock_k_, warp_m_, warp_n_, \ | |||||
warp_k_, stage_, access_size_, out_elements_per_access_, \ | |||||
without_shared_load_) \ | |||||
if (threadblock_shape.m() == threadblock_m_ && \ | |||||
threadblock_shape.n() == threadblock_n_ && \ | |||||
threadblock_shape.k() == threadblock_k_ && \ | |||||
warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ | |||||
warp_shape.k() == warp_k_ && stages == stage_ && \ | |||||
access_size == access_size_ && \ | |||||
out_elements_per_access == out_elements_per_access_ && \ | |||||
without_shared_load == without_shared_load_) { \ | |||||
using ThreadBlockShape = \ | |||||
cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \ | |||||
threadblock_k_>; \ | |||||
using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \ | |||||
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; \ | |||||
using ElementOutput = cutlass::int4b_t; \ | |||||
using ElementAccumulator = int32_t; \ | |||||
using ElementBias = int32_t; \ | |||||
using ElementCompute = float; \ | |||||
using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode; \ | |||||
switch (nonlinear_mode) { \ | |||||
case NonlineMode::IDENTITY: { \ | |||||
using EpilogueOp = cutlass::epilogue::thread:: \ | |||||
BiasAddLinearCombinationClamp< \ | |||||
ElementOutput, out_elements_per_access_, \ | |||||
ElementAccumulator, ElementBias, \ | |||||
ElementCompute>; \ | |||||
typename EpilogueOp::Params epilogue{alpha, beta, gamma}; \ | |||||
RUN_CUTLASS_WRAPPER(stage_, access_size_, \ | |||||
without_shared_load_); \ | |||||
} \ | |||||
case NonlineMode::RELU: { \ | |||||
using EpilogueOp = cutlass::epilogue::thread:: \ | |||||
BiasAddLinearCombinationReluClamp< \ | |||||
ElementOutput, out_elements_per_access_, \ | |||||
ElementAccumulator, ElementBias, \ | |||||
ElementCompute>; \ | |||||
typename EpilogueOp::Params epilogue{alpha, beta, gamma, 0}; \ | |||||
RUN_CUTLASS_WRAPPER(stage_, access_size_, \ | |||||
without_shared_load_); \ | |||||
} \ | |||||
case NonlineMode::H_SWISH: { \ | |||||
using EpilogueOp = cutlass::epilogue::thread:: \ | |||||
BiasAddLinearCombinationHSwishClamp< \ | |||||
ElementOutput, out_elements_per_access_, \ | |||||
ElementAccumulator, ElementBias, \ | |||||
ElementCompute>; \ | |||||
typename EpilogueOp::Params epilogue{alpha, beta, gamma, \ | |||||
scale}; \ | |||||
RUN_CUTLASS_WRAPPER(stage_, access_size_, \ | |||||
without_shared_load_); \ | |||||
} \ | |||||
default: \ | |||||
megdnn_assert( \ | |||||
false, \ | |||||
"unsupported nonlinear mode for conv bias operator"); \ | |||||
} \ | |||||
} | |||||
#define DISPATCH_KERNEL \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 32, 8, false); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 16, 8, false); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 8, 8, false); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 32, 8, false); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 16, 8, false); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 8, 8, false); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 32, 8, true); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 16, 8, true); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 8, 8, true); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 32, 16, true); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 16, 16, true); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 8, 16, true); \ | |||||
megdnn_assert(false, \ | |||||
"unsupported threadblock shape (%dx%dx%d) and warp shape " \ | |||||
"(%dx%dx%d) and access_size (%d)", \ | |||||
threadblock_shape.m(), threadblock_shape.n(), \ | |||||
threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ | |||||
warp_shape.k(), access_size); | |||||
DISPATCH_KERNEL; | |||||
#undef RUN_CUTLASS_WRAPPER | |||||
#undef DISPATCH_KERNEL_WITH_TILE_SHAPE | |||||
#undef DISPATCH_KERNEL | |||||
} | |||||
#endif | |||||
#define INST(need_load_from_const_mem) \ | |||||
template void megdnn::cuda::cutlass_wrapper:: \ | |||||
do_conv_bias_int4_int4_implicit_gemm_imma_nhwc< \ | |||||
need_load_from_const_mem>( \ | |||||
const int8_t* d_src, const int8_t* d_filter, \ | |||||
const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, \ | |||||
int* workspace, const convolution::ConvParam& param, \ | |||||
uint32_t nonlinear_mode, float alpha, float beta, \ | |||||
float gamma, float scale, \ | |||||
const GemmCoord& threadblock_shape, \ | |||||
const GemmCoord& warp_shape, const int32_t access_size, \ | |||||
int stages, cudaStream_t stream); | |||||
INST(true); | |||||
INST(false); | |||||
#undef INST | |||||
/* ====== cutlass kernel wrapper for uint4 x int4 nhwc layout ====== */ | |||||
#if MEGDNN_TEGRA_X1 | |||||
template <bool NeedLoadFromConstMem> | |||||
void megdnn::cuda::cutlass_wrapper:: | |||||
do_conv_bias_uint4_int4_implicit_gemm_imma_nhwc( | |||||
const uint8_t* /* d_src */, const int8_t* /* d_filter */, | |||||
const int32_t* /* d_bias */, const uint8_t* /* d_z */, | |||||
uint8_t* /* d_dst */, int* /* workspace */, | |||||
const convolution::ConvParam& /* param */, | |||||
uint32_t /* nonlinear_mode */, float /* alpha */, | |||||
float /* beta */, float /* gamma */, float /* delta */, | |||||
float /* theta */, float /* scale */, | |||||
uint8_t /* src_zero_point */, | |||||
const GemmCoord& /* threadblock_shape */, | |||||
const GemmCoord& /* warp_shape */, | |||||
const int32_t /* access_size */, int /* stages */, | |||||
cudaStream_t /* stream */) {} | |||||
#else | |||||
template <bool NeedLoadFromConstMem> | |||||
void megdnn::cuda::cutlass_wrapper:: | |||||
do_conv_bias_uint4_int4_implicit_gemm_imma_nhwc( | |||||
const uint8_t* d_src, const int8_t* d_filter, | |||||
const int32_t* d_bias, const uint8_t* d_z, uint8_t* d_dst, | |||||
int* workspace, const convolution::ConvParam& param, | |||||
uint32_t nonlinear_mode, float alpha, float beta, float gamma, | |||||
float delta, float theta, float /* scale */, | |||||
uint8_t src_zero_point, const GemmCoord& threadblock_shape, | |||||
const GemmCoord& warp_shape, const int32_t access_size, | |||||
int stages, cudaStream_t stream) { | |||||
bool without_shared_load = | |||||
((param.co % threadblock_shape.n() == 0) && | |||||
(threadblock_shape.n() == 32 || threadblock_shape.n() == 64)); | |||||
int out_elements_per_access = | |||||
without_shared_load ? threadblock_shape.n() / 4 : 8; | |||||
#define RUN_CUTLASS_WRAPPER(stage_, access_size_, without_shared_load_) \ | |||||
using Convolution = cutlass::conv::device::Convolution< \ | |||||
cutlass::uint4b_t, cutlass::layout::TensorNHWC, cutlass::int4b_t, \ | |||||
cutlass::layout::TensorNCxHWx<access_size_>, ElementOutput, \ | |||||
cutlass::layout::TensorNHWC, int32_t, cutlass::layout::TensorNHWC, \ | |||||
int32_t, cutlass::conv::ConvType::kConvolution, \ | |||||
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \ | |||||
ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ | |||||
cutlass::conv::threadblock:: \ | |||||
ConvolutionFpropTransThreadblockSwizzle, \ | |||||
stage_, access_size_, access_size_, NeedLoadFromConstMem, \ | |||||
cutlass::arch::OpMultiplyAddSaturate, \ | |||||
cutlass::conv::ImplicitGemmMode::GEMM_TN, without_shared_load_>; \ | |||||
typename Convolution::ConvolutionParameter conv_param( \ | |||||
param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ | |||||
param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ | |||||
param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \ | |||||
return cutlass_convolution_wrapper<Convolution>( \ | |||||
reinterpret_cast<const cutlass::uint4b_t*>(d_src), \ | |||||
reinterpret_cast<const cutlass::int4b_t*>(d_filter), d_bias, \ | |||||
reinterpret_cast<const cutlass::uint4b_t*>(d_z), \ | |||||
reinterpret_cast<cutlass::uint4b_t*>(d_dst), workspace, \ | |||||
conv_param, epilogue, stream, {src_zero_point}); | |||||
#define DISPATCH_KERNEL_WITH_TILE_SHAPE( \ | |||||
threadblock_m_, threadblock_n_, threadblock_k_, warp_m_, warp_n_, \ | |||||
warp_k_, stage_, access_size_, out_elements_per_access_, \ | |||||
without_shared_load_) \ | |||||
if (threadblock_shape.m() == threadblock_m_ && \ | |||||
threadblock_shape.n() == threadblock_n_ && \ | |||||
threadblock_shape.k() == threadblock_k_ && \ | |||||
warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ | |||||
warp_shape.k() == warp_k_ && stages == stage_ && \ | |||||
access_size == access_size_ && \ | |||||
out_elements_per_access == out_elements_per_access_ && \ | |||||
without_shared_load == without_shared_load_) { \ | |||||
using ThreadBlockShape = \ | |||||
cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \ | |||||
threadblock_k_>; \ | |||||
using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \ | |||||
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; \ | |||||
using ElementOutput = cutlass::uint4b_t; \ | |||||
using ElementAccumulator = int32_t; \ | |||||
using ElementBias = int32_t; \ | |||||
using ElementCompute = float; \ | |||||
using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode; \ | |||||
switch (nonlinear_mode) { \ | |||||
case NonlineMode::IDENTITY: { \ | |||||
using EpilogueOp = cutlass::epilogue::thread:: \ | |||||
BiasAddLinearCombinationClamp< \ | |||||
ElementOutput, out_elements_per_access_, \ | |||||
ElementAccumulator, ElementBias, \ | |||||
ElementCompute>; \ | |||||
typename EpilogueOp::Params epilogue{alpha, beta, gamma, \ | |||||
delta + theta}; \ | |||||
RUN_CUTLASS_WRAPPER(stage_, access_size_, \ | |||||
without_shared_load_); \ | |||||
} \ | |||||
case NonlineMode::RELU: { \ | |||||
using EpilogueOp = cutlass::epilogue::thread:: \ | |||||
BiasAddLinearCombinationReluClamp< \ | |||||
ElementOutput, out_elements_per_access_, \ | |||||
ElementAccumulator, ElementBias, \ | |||||
ElementCompute>; \ | |||||
typename EpilogueOp::Params epilogue{alpha, beta, gamma, \ | |||||
0, delta, theta}; \ | |||||
RUN_CUTLASS_WRAPPER(stage_, access_size_, \ | |||||
without_shared_load_); \ | |||||
} \ | |||||
default: \ | |||||
megdnn_assert( \ | |||||
false, \ | |||||
"unsupported nonlinear mode for conv bias operator"); \ | |||||
} \ | |||||
} | |||||
#define DISPATCH_KERNEL \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 32, 8, false); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 16, 8, false); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 8, 8, false); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 32, 8, false); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 16, 8, false); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 8, 8, false); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 32, 8, true); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 16, 8, true); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 8, 8, true); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 32, 16, true); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 16, 16, true); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 8, 16, true); \ | |||||
megdnn_assert(false, \ | |||||
"unsupported threadblock shape (%dx%dx%d) and warp shape " \ | |||||
"(%dx%dx%d) and access_size (%d)", \ | |||||
threadblock_shape.m(), threadblock_shape.n(), \ | |||||
threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ | |||||
warp_shape.k(), access_size); | |||||
DISPATCH_KERNEL; | |||||
#undef RUN_CUTLASS_WRAPPER | |||||
#undef DISPATCH_KERNEL_WITH_TILE_SHAPE | |||||
#undef DISPATCH_KERNEL | |||||
} | |||||
#endif | |||||
#define INST(need_load_from_const_mem) \ | |||||
template void megdnn::cuda::cutlass_wrapper:: \ | |||||
do_conv_bias_uint4_int4_implicit_gemm_imma_nhwc< \ | |||||
need_load_from_const_mem>( \ | |||||
const uint8_t* d_src, const int8_t* d_filter, \ | |||||
const int32_t* d_bias, const uint8_t* d_z, uint8_t* d_dst, \ | |||||
int* workspace, const convolution::ConvParam& param, \ | |||||
uint32_t nonlinear_mode, float alpha, float beta, \ | |||||
float gamma, float delta, float theta, float scale, \ | |||||
uint8_t src_zero_point, \ | |||||
const GemmCoord& threadblock_shape, \ | |||||
const GemmCoord& warp_shape, const int32_t access_size, \ | |||||
int stages, cudaStream_t stream); | |||||
INST(true); | |||||
INST(false); | |||||
#undef INST | |||||
// vim: syntax=cuda.doxygen |
@@ -1,804 +0,0 @@ | |||||
/** | |||||
* \file dnn/src/cuda/conv_bias/cutlass_convolution_wrapper.cu | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
// ignore warning of cutlass | |||||
#pragma GCC diagnostic push | |||||
#pragma GCC diagnostic ignored "-Wunused-parameter" | |||||
#pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||||
#if !MEGDNN_TEGRA_X1 | |||||
#include "cutlass/convolution/device/convolution.h" | |||||
#endif | |||||
#include "src/common/opr_param_defs_enumv.cuh" | |||||
#include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh" | |||||
#pragma GCC diagnostic pop | |||||
using namespace megdnn; | |||||
using namespace cuda; | |||||
using namespace cutlass_wrapper; | |||||
/* ====== cutlass kernel wrapper for int8 nchw32 layout ====== */ | |||||
#if MEGDNN_TEGRA_X1 | |||||
template <bool NeedLoadFromConstMem> | |||||
void megdnn::cuda::cutlass_wrapper:: | |||||
do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32( | |||||
const int8_t* /* d_src */, const int8_t* /* d_filter */, | |||||
const int32_t* /* d_bias */, const int8_t* /* d_z */, | |||||
int8_t* /* d_dst */, int* /* workspace */, | |||||
const convolution::ConvParam& /* param */, | |||||
uint32_t /* nonlinear_mode */, float /* alpha */, | |||||
float /* beta */, float /* gamma */, float /* scale */, | |||||
const GemmCoord& /* threadblock_shape */, | |||||
const GemmCoord& /* warp_shape */, int /* stages */, | |||||
cudaStream_t /* stream */) {} | |||||
#else | |||||
template <bool NeedLoadFromConstMem> | |||||
void megdnn::cuda::cutlass_wrapper:: | |||||
do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32( | |||||
const int8_t* d_src, const int8_t* d_filter, | |||||
const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, | |||||
int* workspace, const convolution::ConvParam& param, | |||||
uint32_t nonlinear_mode, float alpha, float beta, float gamma, | |||||
float scale, const GemmCoord& threadblock_shape, | |||||
const GemmCoord& warp_shape, int stages, cudaStream_t stream) { | |||||
#define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \ | |||||
threadblock_k_, warp_m_, warp_n_, \ | |||||
warp_k_, stage_) \ | |||||
if (threadblock_shape.m() == threadblock_m_ && \ | |||||
threadblock_shape.n() == threadblock_n_ && \ | |||||
threadblock_shape.k() == threadblock_k_ && \ | |||||
warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ | |||||
warp_shape.k() == warp_k_ && stages == stage_) { \ | |||||
using ThreadBlockShape = \ | |||||
cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \ | |||||
threadblock_k_>; \ | |||||
using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \ | |||||
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; \ | |||||
using Convolution = cutlass::conv::device::Convolution< \ | |||||
int8_t, cutlass::layout::TensorNCxHWx<32>, int8_t, \ | |||||
cutlass::layout::TensorCxRSKx<32>, ElementOutput, \ | |||||
cutlass::layout::TensorNCxHWx<32>, int32_t, \ | |||||
cutlass::layout::TensorNCxHWx<32>, int32_t, \ | |||||
cutlass::conv::ConvType::kConvolution, \ | |||||
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \ | |||||
ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ | |||||
cutlass::conv::threadblock:: \ | |||||
ConvolutionFpropTransThreadblockSwizzle, \ | |||||
stage_, 16, 16, NeedLoadFromConstMem, \ | |||||
cutlass::arch::OpMultiplyAddSaturate, \ | |||||
cutlass::conv::ImplicitGemmMode::GEMM_TN, true>; \ | |||||
typename Convolution::ConvolutionParameter conv_param( \ | |||||
param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ | |||||
param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ | |||||
param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \ | |||||
return cutlass_convolution_wrapper<Convolution>( \ | |||||
d_src, d_filter, d_bias, d_z, d_dst, workspace, conv_param, \ | |||||
epilogue, stream); \ | |||||
} | |||||
#define DISPATCH_KERNEL \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 256, 64, 64, 64, 64, 2); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(256, 128, 64, 64, 64, 64, 2); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 64, 64, 64, 64, 2); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 32, 64, 2); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 64, 32, 64, 64, 2); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 32, 64, 32, 32, 1); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 32, 64, 32, 32, 1); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 32, 64, 32, 1); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 1); \ | |||||
megdnn_assert(false, \ | |||||
"unsupported threadblock shape (%dx%dx%d) and warp shape " \ | |||||
"(%dx%dx%d)", \ | |||||
threadblock_shape.m(), threadblock_shape.n(), \ | |||||
threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ | |||||
warp_shape.k()); | |||||
using ElementOutput = int8_t; | |||||
using ElementAccumulator = int32_t; | |||||
using ElementBias = int32_t; | |||||
using ElementCompute = float; | |||||
using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode; | |||||
switch (nonlinear_mode) { | |||||
case NonlineMode::IDENTITY: { | |||||
using EpilogueOp = | |||||
cutlass::epilogue::thread::BiasAddLinearCombinationClamp< | |||||
ElementOutput, 8, ElementAccumulator, ElementBias, | |||||
ElementCompute>; | |||||
typename EpilogueOp::Params epilogue{alpha, beta, gamma}; | |||||
DISPATCH_KERNEL; | |||||
} | |||||
case NonlineMode::RELU: { | |||||
using EpilogueOp = cutlass::epilogue::thread:: | |||||
BiasAddLinearCombinationReluClamp< | |||||
ElementOutput, 8, ElementAccumulator, ElementBias, | |||||
ElementCompute>; | |||||
typename EpilogueOp::Params epilogue{alpha, beta, gamma, 0}; | |||||
DISPATCH_KERNEL; | |||||
} | |||||
case NonlineMode::H_SWISH: { | |||||
using EpilogueOp = cutlass::epilogue::thread:: | |||||
BiasAddLinearCombinationHSwishClamp< | |||||
ElementOutput, 8, ElementAccumulator, ElementBias, | |||||
ElementCompute>; | |||||
typename EpilogueOp::Params epilogue{alpha, beta, gamma, scale}; | |||||
DISPATCH_KERNEL; | |||||
} | |||||
default: | |||||
megdnn_assert(false, | |||||
"unsupported nonlinear mode for conv bias operator"); | |||||
} | |||||
#undef DISPATCH_KERNEL_WITH_TILE_SHAPE | |||||
#undef DISPATCH_KERNEL | |||||
} | |||||
#endif | |||||
#define INST(need_load_from_const_mem) \ | |||||
template void megdnn::cuda::cutlass_wrapper:: \ | |||||
do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32< \ | |||||
need_load_from_const_mem>( \ | |||||
const int8_t* d_src, const int8_t* d_filter, \ | |||||
const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, \ | |||||
int* workspace, const convolution::ConvParam& param, \ | |||||
uint32_t nonlinear_mode, float alpha, float beta, \ | |||||
float gamma, float scale, \ | |||||
const GemmCoord& threadblock_shape, \ | |||||
const GemmCoord& warp_shape, int stages, \ | |||||
cudaStream_t stream); | |||||
INST(true); | |||||
INST(false); | |||||
#undef INST | |||||
/* ===== cutlass kernel wrapper for int8 nchw32 layout and nchw4 output ===== */ | |||||
#if MEGDNN_TEGRA_X1 | |||||
template <bool NeedLoadFromConstMem> | |||||
void megdnn::cuda::cutlass_wrapper:: | |||||
do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4( | |||||
const int8_t* /* d_src */, const int8_t* /* d_filter */, | |||||
const int32_t* /* d_bias */, const int8_t* /* d_z */, | |||||
int8_t* /* d_dst */, int* /* workspace */, | |||||
const convolution::ConvParam& /* param */, | |||||
uint32_t /* nonlinear_mode */, float /* alpha */, | |||||
float /* beta */, float /* gamma */, float /* scale */, | |||||
const GemmCoord& /* threadblock_shape */, | |||||
const GemmCoord& /* warp_shape */, int /* stages */, | |||||
cudaStream_t /* stream */) {} | |||||
#else | |||||
template <bool NeedLoadFromConstMem> | |||||
void megdnn::cuda::cutlass_wrapper:: | |||||
do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4( | |||||
const int8_t* d_src, const int8_t* d_filter, | |||||
const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, | |||||
int* workspace, const convolution::ConvParam& param, | |||||
uint32_t nonlinear_mode, float alpha, float beta, float gamma, | |||||
float scale, const GemmCoord& threadblock_shape, | |||||
const GemmCoord& warp_shape, int stages, cudaStream_t stream) { | |||||
#define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \ | |||||
threadblock_k_, warp_m_, warp_n_, \ | |||||
warp_k_, stage_) \ | |||||
if (threadblock_shape.m() == threadblock_m_ && \ | |||||
threadblock_shape.n() == threadblock_n_ && \ | |||||
threadblock_shape.k() == threadblock_k_ && \ | |||||
warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ | |||||
warp_shape.k() == warp_k_ && stages == stage_) { \ | |||||
using ThreadBlockShape = \ | |||||
cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \ | |||||
threadblock_k_>; \ | |||||
using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \ | |||||
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; \ | |||||
using Convolution = cutlass::conv::device::Convolution< \ | |||||
int8_t, cutlass::layout::TensorNCxHWx<32>, int8_t, \ | |||||
cutlass::layout::TensorCxRSKx<32>, ElementOutput, \ | |||||
cutlass::layout::TensorNCxHWx<4>, int32_t, \ | |||||
cutlass::layout::TensorNCxHWx<4>, int32_t, \ | |||||
cutlass::conv::ConvType::kConvolution, \ | |||||
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \ | |||||
ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ | |||||
cutlass::conv::threadblock:: \ | |||||
ConvolutionFpropNCxHWxThreadblockSwizzle, \ | |||||
stage_, 16, 16, NeedLoadFromConstMem>; \ | |||||
typename Convolution::ConvolutionParameter conv_param( \ | |||||
param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ | |||||
param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ | |||||
param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \ | |||||
return cutlass_convolution_wrapper<Convolution>( \ | |||||
d_src, d_filter, d_bias, d_z, d_dst, workspace, conv_param, \ | |||||
epilogue, stream); \ | |||||
} | |||||
#define DISPATCH_KERNEL \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 256, 64, 64, 64, 64, 2); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(256, 128, 64, 64, 64, 64, 2); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 64, 64, 64, 64, 2); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 32, 64, 2); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 64, 32, 64, 64, 2); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 32, 64, 32, 32, 1); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 32, 64, 32, 32, 1); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 32, 64, 32, 1); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 1); \ | |||||
megdnn_assert(false, \ | |||||
"unsupported threadblock shape (%dx%dx%d) and warp shape " \ | |||||
"(%dx%dx%d)", \ | |||||
threadblock_shape.m(), threadblock_shape.n(), \ | |||||
threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ | |||||
warp_shape.k()); | |||||
using ElementOutput = int8_t; | |||||
using ElementAccumulator = int32_t; | |||||
using ElementBias = int32_t; | |||||
using ElementCompute = float; | |||||
using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode; | |||||
switch (nonlinear_mode) { | |||||
case NonlineMode::IDENTITY: { | |||||
using EpilogueOp = | |||||
cutlass::epilogue::thread::BiasAddLinearCombinationClamp< | |||||
ElementOutput, 4, ElementAccumulator, ElementBias, | |||||
ElementCompute>; | |||||
typename EpilogueOp::Params epilogue{alpha, beta, gamma}; | |||||
DISPATCH_KERNEL; | |||||
} | |||||
case NonlineMode::RELU: { | |||||
using EpilogueOp = cutlass::epilogue::thread:: | |||||
BiasAddLinearCombinationReluClamp< | |||||
ElementOutput, 4, ElementAccumulator, ElementBias, | |||||
ElementCompute>; | |||||
typename EpilogueOp::Params epilogue{alpha, beta, gamma, 0}; | |||||
DISPATCH_KERNEL; | |||||
} | |||||
case NonlineMode::H_SWISH: { | |||||
using EpilogueOp = cutlass::epilogue::thread:: | |||||
BiasAddLinearCombinationHSwishClamp< | |||||
ElementOutput, 4, ElementAccumulator, ElementBias, | |||||
ElementCompute>; | |||||
typename EpilogueOp::Params epilogue{alpha, beta, gamma, scale}; | |||||
DISPATCH_KERNEL; | |||||
} | |||||
default: | |||||
megdnn_assert(false, | |||||
"unsupported nonlinear mode for conv bias operator"); | |||||
} | |||||
#undef DISPATCH_KERNEL_WITH_TILE_SHAPE | |||||
#undef DISPATCH_KERNEL | |||||
} | |||||
#endif | |||||
#define INST(need_load_from_const_mem) \ | |||||
template void megdnn::cuda::cutlass_wrapper:: \ | |||||
do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4< \ | |||||
need_load_from_const_mem>( \ | |||||
const int8_t* d_src, const int8_t* d_filter, \ | |||||
const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, \ | |||||
int* workspace, const convolution::ConvParam& param, \ | |||||
uint32_t nonlinear_mode, float alpha, float beta, \ | |||||
float gamma, float scale, \ | |||||
const GemmCoord& threadblock_shape, \ | |||||
const GemmCoord& warp_shape, int stages, \ | |||||
cudaStream_t stream); | |||||
INST(true); | |||||
INST(false); | |||||
#undef INST | |||||
/* ====== cutlass kernel wrapper for int8 nchw4 layout ====== */ | |||||
#if MEGDNN_TEGRA_X1 | |||||
template <bool NeedLoadFromConstMem> | |||||
void megdnn::cuda::cutlass_wrapper:: | |||||
do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4( | |||||
const int8_t* /* d_src */, const int8_t* /* d_filter */, | |||||
const int32_t* /* d_bias */, const int8_t* /* d_z */, | |||||
int8_t* /* d_dst */, int* /* workspace */, | |||||
const convolution::ConvParam& /* param */, | |||||
uint32_t /* nonlinear_mode */, float /* alpha */, | |||||
float /* beta */, float /* gamma */, float /* scale */, | |||||
const GemmCoord& /* threadblock_shape */, | |||||
const GemmCoord& /* warp_shape */, int /* stages */, | |||||
cudaStream_t /* stream */) {} | |||||
#else | |||||
template <bool NeedLoadFromConstMem> | |||||
void megdnn::cuda::cutlass_wrapper:: | |||||
do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4( | |||||
const int8_t* d_src, const int8_t* d_filter, | |||||
const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, | |||||
int* workspace, const convolution::ConvParam& param, | |||||
uint32_t nonlinear_mode, float alpha, float beta, float gamma, | |||||
float scale, const GemmCoord& threadblock_shape, | |||||
const GemmCoord& warp_shape, int stages, cudaStream_t stream) { | |||||
#define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \ | |||||
threadblock_k_, warp_m_, warp_n_, \ | |||||
warp_k_, stage_, aligned_) \ | |||||
if (threadblock_shape.m() == threadblock_m_ && \ | |||||
threadblock_shape.n() == threadblock_n_ && \ | |||||
threadblock_shape.k() == threadblock_k_ && \ | |||||
warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ | |||||
warp_shape.k() == warp_k_ && stages == stage_) { \ | |||||
using ThreadBlockShape = \ | |||||
cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \ | |||||
threadblock_k_>; \ | |||||
using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \ | |||||
using InstructionShape = cutlass::gemm::GemmShape<1, 1, 4>; \ | |||||
using Convolution = cutlass::conv::device::Convolution< \ | |||||
int8_t, cutlass::layout::TensorNCxHWx<4>, int8_t, \ | |||||
cutlass::layout::TensorCxRSKx<4>, ElementOutput, \ | |||||
cutlass::layout::TensorNCxHWx<4>, int32_t, \ | |||||
cutlass::layout::TensorNCxHWx<4>, int32_t, \ | |||||
cutlass::conv::ConvType::kConvolution, \ | |||||
cutlass::arch::OpClassSimt, cutlass::arch::Sm61, \ | |||||
ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ | |||||
cutlass::conv::threadblock:: \ | |||||
ConvolutionFpropNCxHWxThreadblockSwizzle, \ | |||||
stage_, 4, aligned_, NeedLoadFromConstMem, \ | |||||
cutlass::arch::OpMultiplyAdd>; \ | |||||
typename Convolution::ConvolutionParameter conv_param( \ | |||||
param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ | |||||
param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ | |||||
param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \ | |||||
return cutlass_convolution_wrapper<Convolution>( \ | |||||
d_src, d_filter, d_bias, d_z, d_dst, workspace, conv_param, \ | |||||
epilogue, stream); \ | |||||
} | |||||
#define DISPATCH_KERNEL \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 32, 64, 32, 32, 2, 16); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 32, 64, 32, 32, 2, 16); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 64, 32, 32, 2, 16); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 32, 64, 32, 32, 2, 16); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 2, 16); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 32, 32, 64, 32, 2, 16); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 32, 32, 64, 32, 32, 2, 16); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 128, 16, 16, 128, 16, 1, 8); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 64, 8, 16, 64, 8, 2, 4); \ | |||||
megdnn_assert(false, \ | |||||
"unsupported threadblock shape (%dx%dx%d) and warp shape " \ | |||||
"(%dx%dx%d)", \ | |||||
threadblock_shape.m(), threadblock_shape.n(), \ | |||||
threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ | |||||
warp_shape.k()); | |||||
using ElementOutput = int8_t; | |||||
using ElementAccumulator = int32_t; | |||||
using ElementBias = int32_t; | |||||
using ElementCompute = float; | |||||
using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode; | |||||
switch (nonlinear_mode) { | |||||
case NonlineMode::IDENTITY: { | |||||
using EpilogueOp = | |||||
cutlass::epilogue::thread::BiasAddLinearCombinationClamp< | |||||
ElementOutput, 4, ElementAccumulator, ElementBias, | |||||
ElementCompute>; | |||||
typename EpilogueOp::Params epilogue{alpha, beta, gamma}; | |||||
DISPATCH_KERNEL; | |||||
} | |||||
case NonlineMode::RELU: { | |||||
using EpilogueOp = cutlass::epilogue::thread:: | |||||
BiasAddLinearCombinationReluClamp< | |||||
ElementOutput, 4, ElementAccumulator, ElementBias, | |||||
ElementCompute>; | |||||
typename EpilogueOp::Params epilogue{alpha, beta, gamma, 0}; | |||||
DISPATCH_KERNEL; | |||||
} | |||||
case NonlineMode::H_SWISH: { | |||||
using EpilogueOp = cutlass::epilogue::thread:: | |||||
BiasAddLinearCombinationHSwishClamp< | |||||
ElementOutput, 4, ElementAccumulator, ElementBias, | |||||
ElementCompute>; | |||||
typename EpilogueOp::Params epilogue{alpha, beta, gamma, scale}; | |||||
DISPATCH_KERNEL; | |||||
} | |||||
default: | |||||
megdnn_assert(false, | |||||
"unsupported nonlinear mode for conv bias operator"); | |||||
} | |||||
#undef DISPATCH_KERNEL_WITH_TILE_SHAPE | |||||
#undef DISPATCH_KERNEL | |||||
} | |||||
#endif | |||||
#define INST(need_load_from_const_mem) \ | |||||
template void megdnn::cuda::cutlass_wrapper:: \ | |||||
do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4< \ | |||||
need_load_from_const_mem>( \ | |||||
const int8_t* d_src, const int8_t* d_filter, \ | |||||
const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, \ | |||||
int* workspace, const convolution::ConvParam& param, \ | |||||
uint32_t nonlinear_mode, float alpha, float beta, \ | |||||
float gamma, float scale, \ | |||||
const GemmCoord& threadblock_shape, \ | |||||
const GemmCoord& warp_shape, int stages, \ | |||||
cudaStream_t stream); | |||||
INST(true); | |||||
INST(false); | |||||
#undef INST | |||||
/* ====== cutlass kernel wrapper for int8 nchw4 layout and nchw output ====== */ | |||||
#if MEGDNN_TEGRA_X1 | |||||
template <bool NeedLoadFromConstMem> | |||||
void megdnn::cuda::cutlass_wrapper:: | |||||
do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nchw( | |||||
const int8_t* /* d_src */, const int8_t* /* d_filter */, | |||||
const float* /* d_bias */, const float* /* d_z */, | |||||
float* /* d_dst */, int* /* workspace */, | |||||
const convolution::ConvParam& /* param */, | |||||
uint32_t /* nonlinear_mode */, float /* alpha */, | |||||
float /* beta */, float /* gamma */, float /* scale */, | |||||
const GemmCoord& /* threadblock_shape */, | |||||
const GemmCoord& /* warp_shape */, int /* stages */, | |||||
cudaStream_t /* stream */) {} | |||||
#else | |||||
template <bool NeedLoadFromConstMem> | |||||
void megdnn::cuda::cutlass_wrapper:: | |||||
do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nchw( | |||||
const int8_t* d_src, const int8_t* d_filter, | |||||
const float* d_bias, const float* d_z, float* d_dst, | |||||
int* workspace, const convolution::ConvParam& param, | |||||
uint32_t nonlinear_mode, float alpha, float beta, float gamma, | |||||
float scale, const GemmCoord& threadblock_shape, | |||||
const GemmCoord& warp_shape, int stages, cudaStream_t stream) { | |||||
#define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \ | |||||
threadblock_k_, warp_m_, warp_n_, \ | |||||
warp_k_, stages_, aligned_) \ | |||||
if (threadblock_shape.m() == threadblock_m_ && \ | |||||
threadblock_shape.n() == threadblock_n_ && \ | |||||
threadblock_shape.k() == threadblock_k_ && \ | |||||
warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ | |||||
warp_shape.k() == warp_k_ && stages == stages_) { \ | |||||
using ThreadBlockShape = \ | |||||
cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \ | |||||
threadblock_k_>; \ | |||||
using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \ | |||||
using InstructionShape = cutlass::gemm::GemmShape<1, 1, 4>; \ | |||||
using Convolution = cutlass::conv::device::Convolution< \ | |||||
int8_t, cutlass::layout::TensorNCxHWx<4>, int8_t, \ | |||||
cutlass::layout::TensorCxRSKx<4>, ElementOutput, \ | |||||
cutlass::layout::TensorNCHW, float, \ | |||||
cutlass::layout::TensorNCHW, int32_t, \ | |||||
cutlass::conv::ConvType::kConvolution, \ | |||||
cutlass::arch::OpClassSimt, cutlass::arch::Sm61, \ | |||||
ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ | |||||
cutlass::conv::threadblock:: \ | |||||
ConvolutionFpropNCxHWxThreadblockSwizzle, \ | |||||
stages_, 4, aligned_, NeedLoadFromConstMem, \ | |||||
cutlass::arch::OpMultiplyAdd>; \ | |||||
typename Convolution::ConvolutionParameter conv_param( \ | |||||
param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ | |||||
param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ | |||||
param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \ | |||||
return cutlass_convolution_wrapper<Convolution>( \ | |||||
d_src, d_filter, d_bias, d_z, d_dst, workspace, conv_param, \ | |||||
epilogue, stream); \ | |||||
} | |||||
#define DISPATCH_KERNEL \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 32, 64, 32, 32, 2, 16); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 32, 64, 32, 32, 2, 16); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 64, 32, 32, 2, 16); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 32, 64, 32, 32, 2, 16); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 2, 16); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 32, 32, 64, 32, 2, 16); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 32, 32, 64, 32, 32, 2, 16); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 128, 16, 16, 128, 16, 1, 8); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 64, 8, 16, 64, 8, 2, 4); \ | |||||
megdnn_assert(false, \ | |||||
"unsupported threadblock shape (%dx%dx%d) and warp shape " \ | |||||
"(%dx%dx%d)", \ | |||||
threadblock_shape.m(), threadblock_shape.n(), \ | |||||
threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ | |||||
warp_shape.k()); | |||||
using ElementOutput = float; | |||||
using ElementAccumulator = int32_t; | |||||
using ElementBias = float; | |||||
using ElementCompute = float; | |||||
using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode; | |||||
switch (nonlinear_mode) { | |||||
case NonlineMode::IDENTITY: { | |||||
using EpilogueOp = | |||||
cutlass::epilogue::thread::BiasAddLinearCombination< | |||||
ElementOutput, 1, ElementAccumulator, ElementBias, | |||||
ElementCompute>; | |||||
typename EpilogueOp::Params epilogue{alpha, beta, gamma}; | |||||
DISPATCH_KERNEL; | |||||
} | |||||
case NonlineMode::RELU: { | |||||
using EpilogueOp = | |||||
cutlass::epilogue::thread::BiasAddLinearCombinationRelu< | |||||
ElementOutput, 1, ElementAccumulator, ElementBias, | |||||
ElementCompute>; | |||||
typename EpilogueOp::Params epilogue{alpha, beta, gamma, 0}; | |||||
DISPATCH_KERNEL; | |||||
} | |||||
case NonlineMode::H_SWISH: { | |||||
using EpilogueOp = | |||||
cutlass::epilogue::thread::BiasAddLinearCombinationHSwish< | |||||
ElementOutput, 1, ElementAccumulator, ElementBias, | |||||
ElementCompute>; | |||||
typename EpilogueOp::Params epilogue{alpha, beta, gamma, scale}; | |||||
DISPATCH_KERNEL; | |||||
} | |||||
default: | |||||
megdnn_assert(false, | |||||
"unsupported nonlinear mode for conv bias operator"); | |||||
} | |||||
#undef DISPATCH_KERNEL_WITH_TILE_SHAPE | |||||
#undef DISPATCH_KERNEL | |||||
} | |||||
#endif | |||||
#define INST(need_load_from_const_mem) \ | |||||
template void megdnn::cuda::cutlass_wrapper:: \ | |||||
do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nchw< \ | |||||
need_load_from_const_mem>( \ | |||||
const int8_t* d_src, const int8_t* d_filter, \ | |||||
const float* d_bias, const float* d_z, float* d_dst, \ | |||||
int* workspace, const convolution::ConvParam& param, \ | |||||
uint32_t nonlinear_mode, float alpha, float beta, \ | |||||
float gamma, float scale, \ | |||||
const GemmCoord& threadblock_shape, \ | |||||
const GemmCoord& warp_shape, int stages, \ | |||||
cudaStream_t stream); | |||||
INST(true); | |||||
INST(false); | |||||
#undef INST | |||||
/* ===== cutlass kernel wrapper for int8 nchw4 layout and nchw32 output ===== */ | |||||
#if MEGDNN_TEGRA_X1 | |||||
template <bool NeedLoadFromConstMem> | |||||
void megdnn::cuda::cutlass_wrapper:: | |||||
do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32( | |||||
const int8_t* /* d_src */, const int8_t* /* d_filter */, | |||||
const int32_t* /* d_bias */, const int8_t* /* d_z */, | |||||
int8_t* /* d_dst */, int* /* workspace */, | |||||
const convolution::ConvParam& /* param */, | |||||
uint32_t /* nonlinear_mode */, float /* alpha */, | |||||
float /* beta */, float /* gamma */, float /* scale */, | |||||
const GemmCoord& /* threadblock_shape */, | |||||
const GemmCoord& /* warp_shape */, int /* stages */, | |||||
cudaStream_t /* stream */) {} | |||||
#else | |||||
template <bool NeedLoadFromConstMem> | |||||
void megdnn::cuda::cutlass_wrapper:: | |||||
do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32( | |||||
const int8_t* d_src, const int8_t* d_filter, | |||||
const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, | |||||
int* workspace, const convolution::ConvParam& param, | |||||
uint32_t nonlinear_mode, float alpha, float beta, float gamma, | |||||
float scale, const GemmCoord& threadblock_shape, | |||||
const GemmCoord& warp_shape, int stages, cudaStream_t stream) { | |||||
#define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \ | |||||
threadblock_k_, warp_m_, warp_n_, \ | |||||
warp_k_, stages_, aligned_) \ | |||||
if (threadblock_shape.m() == threadblock_m_ && \ | |||||
threadblock_shape.n() == threadblock_n_ && \ | |||||
threadblock_shape.k() == threadblock_k_ && \ | |||||
warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ | |||||
warp_shape.k() == warp_k_ && stages == stages_) { \ | |||||
using ThreadBlockShape = \ | |||||
cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \ | |||||
threadblock_k_>; \ | |||||
using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \ | |||||
using InstructionShape = cutlass::gemm::GemmShape<1, 1, 4>; \ | |||||
using Convolution = cutlass::conv::device::Convolution< \ | |||||
int8_t, cutlass::layout::TensorNCxHWx<4>, int8_t, \ | |||||
cutlass::layout::TensorCxRSKx<4>, ElementOutput, \ | |||||
cutlass::layout::TensorNCxHWx<32>, int32_t, \ | |||||
cutlass::layout::TensorNCxHWx<32>, int32_t, \ | |||||
cutlass::conv::ConvType::kConvolution, \ | |||||
cutlass::arch::OpClassSimt, cutlass::arch::Sm61, \ | |||||
ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ | |||||
cutlass::conv::threadblock:: \ | |||||
ConvolutionFpropNCxHWxThreadblockSwizzle, \ | |||||
stages_, 4, aligned_, NeedLoadFromConstMem, \ | |||||
cutlass::arch::OpMultiplyAdd>; \ | |||||
typename Convolution::ConvolutionParameter conv_param( \ | |||||
param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ | |||||
param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ | |||||
param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \ | |||||
return cutlass_convolution_wrapper<Convolution>( \ | |||||
d_src, d_filter, d_bias, d_z, d_dst, workspace, conv_param, \ | |||||
epilogue, stream); \ | |||||
} | |||||
#define DISPATCH_KERNEL \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 32, 64, 32, 32, 2, 16); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 32, 64, 32, 32, 2, 16); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 64, 32, 32, 2, 16); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 32, 64, 32, 32, 2, 16); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 2, 16); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 32, 32, 64, 32, 2, 16); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 32, 32, 64, 32, 32, 2, 16); \ | |||||
megdnn_assert(false, \ | |||||
"unsupported threadblock shape (%dx%dx%d) and warp shape " \ | |||||
"(%dx%dx%d)", \ | |||||
threadblock_shape.m(), threadblock_shape.n(), \ | |||||
threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ | |||||
warp_shape.k()); | |||||
using ElementOutput = int8_t; | |||||
using ElementAccumulator = int32_t; | |||||
using ElementBias = int32_t; | |||||
using ElementCompute = float; | |||||
using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode; | |||||
switch (nonlinear_mode) { | |||||
case NonlineMode::IDENTITY: { | |||||
using EpilogueOp = | |||||
cutlass::epilogue::thread::BiasAddLinearCombinationClamp< | |||||
ElementOutput, 4, ElementAccumulator, ElementBias, | |||||
ElementCompute>; | |||||
typename EpilogueOp::Params epilogue{alpha, beta, gamma}; | |||||
DISPATCH_KERNEL; | |||||
} | |||||
case NonlineMode::RELU: { | |||||
using EpilogueOp = cutlass::epilogue::thread:: | |||||
BiasAddLinearCombinationReluClamp< | |||||
ElementOutput, 4, ElementAccumulator, ElementBias, | |||||
ElementCompute>; | |||||
typename EpilogueOp::Params epilogue{alpha, beta, gamma, 0}; | |||||
DISPATCH_KERNEL; | |||||
} | |||||
case NonlineMode::H_SWISH: { | |||||
using EpilogueOp = cutlass::epilogue::thread:: | |||||
BiasAddLinearCombinationHSwishClamp< | |||||
ElementOutput, 4, ElementAccumulator, ElementBias, | |||||
ElementCompute>; | |||||
typename EpilogueOp::Params epilogue{alpha, beta, gamma, scale}; | |||||
DISPATCH_KERNEL; | |||||
} | |||||
default: | |||||
megdnn_assert(false, | |||||
"unsupported nonlinear mode for conv bias operator"); | |||||
} | |||||
#undef DISPATCH_KERNEL_WITH_TILE_SHAPE | |||||
#undef DISPATCH_KERNEL | |||||
} | |||||
#endif | |||||
#define INST(need_load_from_const_mem) \ | |||||
template void megdnn::cuda::cutlass_wrapper:: \ | |||||
do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32< \ | |||||
need_load_from_const_mem>( \ | |||||
const int8_t* d_src, const int8_t* d_filter, \ | |||||
const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, \ | |||||
int* workspace, const convolution::ConvParam& param, \ | |||||
uint32_t nonlinear_mode, float alpha, float beta, \ | |||||
float gamma, float scale, \ | |||||
const GemmCoord& threadblock_shape, \ | |||||
const GemmCoord& warp_shape, int stages, \ | |||||
cudaStream_t stream); | |||||
INST(true); | |||||
INST(false); | |||||
#undef INST | |||||
/* ===== cutlass kernel wrapper for nchw4 layout and nhwc output ===== */ | |||||
#if MEGDNN_TEGRA_X1 | |||||
template <bool signedness> | |||||
void megdnn::cuda::cutlass_wrapper:: | |||||
do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nhwc( | |||||
const int8_t* /* d_src */, const int8_t* /* d_filter */, | |||||
const int32_t* /* d_bias */, const int8_t* /* d_z */, | |||||
int8_t* /* d_dst */, int* /* workspace */, | |||||
const convolution::ConvParam& /* param */, | |||||
uint32_t /* nonlinear_mode */, float /* alpha */, | |||||
float /* beta */, float /* gamma */, float /* delta */, | |||||
float /* theta */, float /* scale */, | |||||
const GemmCoord& /* threadblock_shape */, | |||||
const GemmCoord& /* warp_shape */, int /* stages */, | |||||
cudaStream_t /* stream */) {} | |||||
#else | |||||
template <bool signedness> | |||||
void megdnn::cuda::cutlass_wrapper:: | |||||
do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nhwc( | |||||
const int8_t* d_src, const int8_t* d_filter, | |||||
const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, | |||||
int* workspace, const convolution::ConvParam& param, | |||||
uint32_t nonlinear_mode, float alpha, float beta, float gamma, | |||||
float delta, float theta, float scale, | |||||
const GemmCoord& threadblock_shape, const GemmCoord& warp_shape, | |||||
int stages, cudaStream_t stream) { | |||||
#define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \ | |||||
threadblock_k_, warp_m_, warp_n_, \ | |||||
warp_k_, stages_, aligned_) \ | |||||
if (threadblock_shape.m() == threadblock_m_ && \ | |||||
threadblock_shape.n() == threadblock_n_ && \ | |||||
threadblock_shape.k() == threadblock_k_ && \ | |||||
warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ | |||||
warp_shape.k() == warp_k_ && stages == stages_) { \ | |||||
using ThreadBlockShape = \ | |||||
cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \ | |||||
threadblock_k_>; \ | |||||
using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \ | |||||
using InstructionShape = cutlass::gemm::GemmShape<1, 1, 4>; \ | |||||
using Convolution = cutlass::conv::device::Convolution< \ | |||||
int8_t, cutlass::layout::TensorNCxHWx<4>, int8_t, \ | |||||
cutlass::layout::TensorCxRSKx<4>, ElementOutput, \ | |||||
cutlass::layout::TensorNHWC, int32_t, \ | |||||
cutlass::layout::TensorNHWC, int32_t, \ | |||||
cutlass::conv::ConvType::kConvolution, \ | |||||
cutlass::arch::OpClassSimt, cutlass::arch::Sm75, \ | |||||
ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ | |||||
cutlass::conv::threadblock:: \ | |||||
ConvolutionFpropNCxHWxThreadblockSwizzle, \ | |||||
stages_, 4, aligned_, true, cutlass::arch::OpMultiplyAdd>; \ | |||||
typename Convolution::ConvolutionParameter conv_param( \ | |||||
param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ | |||||
param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ | |||||
param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \ | |||||
return cutlass_convolution_wrapper<Convolution>( \ | |||||
d_src, d_filter, d_bias, \ | |||||
reinterpret_cast<const ElementOutput*>(d_z), \ | |||||
reinterpret_cast<ElementOutput*>(d_dst), workspace, \ | |||||
conv_param, epilogue, stream); \ | |||||
} | |||||
#define DISPATCH_KERNEL \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 32, 64, 32, 32, 2, 16); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 32, 64, 32, 32, 2, 16); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 64, 32, 32, 2, 16); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 32, 64, 32, 32, 2, 16); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 2, 16); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 32, 32, 64, 32, 2, 16); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 32, 32, 64, 32, 32, 2, 16); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 128, 16, 16, 128, 16, 1, 8); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 64, 8, 16, 64, 8, 2, 4); \ | |||||
megdnn_assert(false, \ | |||||
"unsupported threadblock shape (%dx%dx%d) and warp shape " \ | |||||
"(%dx%dx%d)", \ | |||||
threadblock_shape.m(), threadblock_shape.n(), \ | |||||
threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ | |||||
warp_shape.k()); | |||||
using ElementOutput = cutlass::integer_subbyte<4, signedness>; | |||||
using ElementAccumulator = int32_t; | |||||
using ElementBias = int32_t; | |||||
using ElementCompute = float; | |||||
using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode; | |||||
switch (nonlinear_mode) { | |||||
case NonlineMode::IDENTITY: { | |||||
using EpilogueOp = | |||||
cutlass::epilogue::thread::BiasAddLinearCombinationClamp< | |||||
ElementOutput, 8, ElementAccumulator, ElementBias, | |||||
ElementCompute>; | |||||
typename EpilogueOp::Params epilogue{alpha, beta, gamma, | |||||
delta + theta}; | |||||
DISPATCH_KERNEL; | |||||
} | |||||
case NonlineMode::RELU: { | |||||
using EpilogueOp = cutlass::epilogue::thread:: | |||||
BiasAddLinearCombinationReluClamp< | |||||
ElementOutput, 8, ElementAccumulator, ElementBias, | |||||
ElementCompute>; | |||||
typename EpilogueOp::Params epilogue{alpha, beta, gamma, | |||||
0, delta, theta}; | |||||
DISPATCH_KERNEL; | |||||
} | |||||
case NonlineMode::H_SWISH: { | |||||
using EpilogueOp = cutlass::epilogue::thread:: | |||||
BiasAddLinearCombinationHSwishClamp< | |||||
ElementOutput, 8, ElementAccumulator, ElementBias, | |||||
ElementCompute>; | |||||
typename EpilogueOp::Params epilogue{alpha, beta, gamma, | |||||
scale, delta, theta}; | |||||
DISPATCH_KERNEL; | |||||
} | |||||
default: | |||||
megdnn_assert(false, | |||||
"unsupported nonlinear mode for conv bias operator"); | |||||
} | |||||
#undef DISPATCH_KERNEL_WITH_TILE_SHAPE | |||||
#undef DISPATCH_KERNEL | |||||
} | |||||
#endif | |||||
#define INST(signedness) \ | |||||
template void megdnn::cuda::cutlass_wrapper:: \ | |||||
do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nhwc<signedness>( \ | |||||
const int8_t* d_src, const int8_t* d_filter, \ | |||||
const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, \ | |||||
int* workspace, const convolution::ConvParam& param, \ | |||||
uint32_t nonlinear_mode, float alpha, float beta, \ | |||||
float gamma, float delta, float theta, float scale, \ | |||||
const GemmCoord& threadblock_shape, \ | |||||
const GemmCoord& warp_shape, int stages, \ | |||||
cudaStream_t stream); | |||||
INST(true); | |||||
INST(false); | |||||
#undef INST | |||||
// vim: syntax=cuda.doxygen |
@@ -1,65 +0,0 @@ | |||||
/** | |||||
* \file | |||||
* dnn/src/cuda/conv_bias/int8/implicit_gemm_conv_bias_cutlass_wrapper.cuinl | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#include "cutlass/convolution/device/convolution.h" | |||||
#include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh" | |||||
using namespace megdnn; | |||||
using namespace cuda; | |||||
using namespace cutlass_wrapper; | |||||
template <typename Convolution> | |||||
void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper( | |||||
const typename Convolution::ElementSrc* d_src, | |||||
const typename Convolution::ElementFilter* d_filter, | |||||
const typename Convolution::ElementBias* d_bias, | |||||
const typename Convolution::ElementDst* d_z, | |||||
typename Convolution::ElementDst* d_dst, int* workspace, | |||||
typename Convolution::ConvolutionParameter const& conv_param, | |||||
typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||||
cudaStream_t stream, typename Convolution::ExtraParam extra_param) { | |||||
typename Convolution::TensorRefSrc tensor_src{ | |||||
const_cast<typename Convolution::ElementSrc*>(d_src), | |||||
Convolution::LayoutSrc::packed( | |||||
{conv_param.N, conv_param.H, conv_param.W, conv_param.C})}; | |||||
typename Convolution::TensorRefFilter tensor_filter{ | |||||
const_cast<typename Convolution::ElementFilter*>(d_filter), | |||||
Convolution::LayoutFilter::packed( | |||||
{conv_param.K, conv_param.R, conv_param.S, conv_param.C})}; | |||||
typename Convolution::TensorRefBias tensor_bias{ | |||||
const_cast<typename Convolution::ElementBias*>(d_bias), | |||||
Convolution::LayoutBias::packed({1, 1, 1, conv_param.K})}; | |||||
typename Convolution::TensorRefDst tensor_z{ | |||||
const_cast<typename Convolution::ElementDst*>(d_z), | |||||
Convolution::LayoutDst::packed( | |||||
{conv_param.N, conv_param.P, conv_param.Q, conv_param.K})}; | |||||
typename Convolution::TensorRefDst tensor_dst{ | |||||
d_dst, | |||||
Convolution::LayoutDst::packed( | |||||
{conv_param.N, conv_param.P, conv_param.Q, conv_param.K})}; | |||||
typename Convolution::Arguments arguments{conv_param, | |||||
tensor_src.non_const_ref(), | |||||
tensor_filter.non_const_ref(), | |||||
tensor_bias.non_const_ref(), | |||||
tensor_z.non_const_ref(), | |||||
tensor_dst.non_const_ref(), | |||||
epilogue, | |||||
{}, | |||||
{}, | |||||
extra_param}; | |||||
Convolution conv_op; | |||||
cutlass_check(conv_op.initialize(arguments, workspace)); | |||||
cutlass_check(conv_op(stream)); | |||||
after_kernel_launch(); | |||||
} | |||||
// vim: syntax=cuda.doxygen |
@@ -10,8 +10,7 @@ | |||||
* implied. | * implied. | ||||
*/ | */ | ||||
#include "./algo.h" | |||||
#include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh" | |||||
#include "src/cuda/conv_bias/algo.h" | |||||
using namespace megdnn; | using namespace megdnn; | ||||
using namespace cuda; | using namespace cuda; | ||||
@@ -81,29 +80,6 @@ ConvBiasForwardImpl::AlgoInt4Int4NCHW64IMMAImplicitGemm::get_constants( | |||||
return {alpha, beta, gamma, delta, theta}; | return {alpha, beta, gamma, delta, theta}; | ||||
} | } | ||||
void ConvBiasForwardImpl::AlgoInt4Int4NCHW64IMMAImplicitGemm::do_exec( | |||||
const ExecArgs& args, void* filter_ptr, void* bias_ptr, void* z_ptr, | |||||
ConvParam kern_param, uint32_t nonlinear_mode, float alpha, float beta, | |||||
float gamma, float delta, float theta, cudaStream_t stream) const { | |||||
float dst_scale = args.dst_layout->dtype.param<dtype::QuantizedS4>().scale; | |||||
cutlass_wrapper::GemmCoord threadblock_shape{m_algo_param.threadblock_m, | |||||
m_algo_param.threadblock_n, | |||||
m_algo_param.threadblock_k}; | |||||
cutlass_wrapper::GemmCoord warp_shape{ | |||||
m_algo_param.warp_m, m_algo_param.warp_n, m_algo_param.warp_k}; | |||||
cutlass_wrapper::do_conv_bias_int4_int4_implicit_gemm_imma_ncdiv64hw64< | |||||
true>(reinterpret_cast<int8_t*>(args.src_tensor->raw_ptr), | |||||
reinterpret_cast<int8_t*>(filter_ptr), | |||||
reinterpret_cast<int32_t*>(bias_ptr), | |||||
reinterpret_cast<int8_t*>(z_ptr), | |||||
reinterpret_cast<int8_t*>(args.dst_tensor->raw_ptr), nullptr, | |||||
kern_param, nonlinear_mode, alpha, beta, gamma, dst_scale, | |||||
threadblock_shape, warp_shape, m_algo_param.stage, stream); | |||||
} | |||||
#endif | #endif | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -10,8 +10,7 @@ | |||||
* implied. | * implied. | ||||
*/ | */ | ||||
#include "./algo.h" | |||||
#include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh" | |||||
#include "src/cuda/conv_bias/algo.h" | |||||
using namespace megdnn; | using namespace megdnn; | ||||
using namespace cuda; | using namespace cuda; | ||||
@@ -81,42 +80,6 @@ ConvBiasForwardImpl::AlgoInt4Int4NHWCIMMAImplicitGemm::get_constants( | |||||
return {alpha, beta, gamma, delta, theta}; | return {alpha, beta, gamma, delta, theta}; | ||||
} | } | ||||
void ConvBiasForwardImpl::AlgoInt4Int4NHWCIMMAImplicitGemm::do_exec( | |||||
const ExecArgs& args, void* filter_ptr, void* bias_ptr, void* z_ptr, | |||||
ConvParam kern_param, uint32_t nonlinear_mode, float alpha, float beta, | |||||
float gamma, float delta, float theta, cudaStream_t stream) const { | |||||
float dst_scale = args.dst_layout->dtype.param<dtype::QuantizedS4>().scale; | |||||
cutlass_wrapper::GemmCoord threadblock_shape{m_algo_param.threadblock_m, | |||||
m_algo_param.threadblock_n, | |||||
m_algo_param.threadblock_k}; | |||||
cutlass_wrapper::GemmCoord warp_shape{ | |||||
m_algo_param.warp_m, m_algo_param.warp_n, m_algo_param.warp_k}; | |||||
if (kern_param.fh == 1 && kern_param.fw == 1) { | |||||
cutlass_wrapper::do_conv_bias_int4_int4_implicit_gemm_imma_nhwc<false>( | |||||
reinterpret_cast<int8_t*>(args.src_tensor->raw_ptr), | |||||
reinterpret_cast<int8_t*>(filter_ptr), | |||||
reinterpret_cast<int32_t*>(bias_ptr), | |||||
reinterpret_cast<int8_t*>(z_ptr), | |||||
reinterpret_cast<int8_t*>(args.dst_tensor->raw_ptr), nullptr, | |||||
kern_param, nonlinear_mode, alpha, beta, gamma, dst_scale, | |||||
threadblock_shape, warp_shape, m_algo_param.access_size, | |||||
m_algo_param.stage, stream); | |||||
} else { | |||||
cutlass_wrapper::do_conv_bias_int4_int4_implicit_gemm_imma_nhwc<true>( | |||||
reinterpret_cast<int8_t*>(args.src_tensor->raw_ptr), | |||||
reinterpret_cast<int8_t*>(filter_ptr), | |||||
reinterpret_cast<int32_t*>(bias_ptr), | |||||
reinterpret_cast<int8_t*>(z_ptr), | |||||
reinterpret_cast<int8_t*>(args.dst_tensor->raw_ptr), nullptr, | |||||
kern_param, nonlinear_mode, alpha, beta, gamma, dst_scale, | |||||
threadblock_shape, warp_shape, m_algo_param.access_size, | |||||
m_algo_param.stage, stream); | |||||
} | |||||
} | |||||
#endif | #endif | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -10,10 +10,9 @@ | |||||
* implied. | * implied. | ||||
*/ | */ | ||||
#include "./algo.h" | |||||
#include "src/common/conv_bias.h" | #include "src/common/conv_bias.h" | ||||
#include "src/cuda/conv_bias/algo.h" | |||||
#include "src/cuda/conv_bias/cutlass_reorder_filter.cuh" | #include "src/cuda/conv_bias/cutlass_reorder_filter.cuh" | ||||
#include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh" | |||||
#include "src/cuda/conv_bias/reduce_filter.cuh" | #include "src/cuda/conv_bias/reduce_filter.cuh" | ||||
#include "src/cuda/convolution_helper/parameter.cuh" | #include "src/cuda/convolution_helper/parameter.cuh" | ||||
#include "src/cuda/utils.h" | #include "src/cuda/utils.h" | ||||
@@ -102,22 +101,40 @@ void ConvBiasForwardImpl::AlgoInt4NCHW64IMMAImplicitGemmBase::exec( | |||||
if (args.z_layout->ndim > 0) | if (args.z_layout->ndim > 0) | ||||
z_ptr = args.z_tensor->raw_ptr; | z_ptr = args.z_tensor->raw_ptr; | ||||
// \note these constants of cutlass epilogue will be passed to method | |||||
// `execute_cutlass_conv_op` by pointer and interpreted as ElementCompute*, | |||||
// a different dtype here results in undefined epilogue behaviors | |||||
float alpha, beta, gamma, delta, theta; | float alpha, beta, gamma, delta, theta; | ||||
std::tie(alpha, beta, gamma, delta, theta) = get_constants(args); | std::tie(alpha, beta, gamma, delta, theta) = get_constants(args); | ||||
float dst_scale = 0.f; | |||||
float threshold = 0.f; | |||||
uint8_t src_zero = 0; | |||||
bool load_from_const = !(fh == 1 && fw == 1); | |||||
bool without_shared_load = true; | |||||
if (args.dst_layout->dtype.enumv() == DTypeEnum::Quantized4Asymm) { | |||||
dst_scale = | |||||
args.dst_layout->dtype.param<dtype::Quantized4Asymm>().scale; | |||||
src_zero = args.src_layout->dtype.param<dtype::Quantized4Asymm>() | |||||
.zero_point; | |||||
} else { // DTypeEnum::QuantizedS4 | |||||
dst_scale = args.dst_layout->dtype.param<dtype::QuantizedS4>().scale; | |||||
} | |||||
ConvParam kern_param; | |||||
kern_param.n = n, kern_param.co = co, kern_param.ci = ci, | |||||
kern_param.hi = hi, kern_param.wi = wi, kern_param.ho = ho, | |||||
kern_param.wo = wo, kern_param.ph = ph, kern_param.pw = pw, | |||||
kern_param.sh = sh, kern_param.sw = sw, kern_param.fh = fh, | |||||
kern_param.fw = fw; | |||||
cudaStream_t stream = cuda_stream(args.opr->handle()); | |||||
uint32_t nonlinear_mode = static_cast<uint32_t>(param.nonlineMode); | |||||
const auto* op = get_cutlass_conv_op(args, ConvOperator::kFprop, | |||||
ConvType::kConvolution, | |||||
load_from_const, without_shared_load); | |||||
cudaStream_t stream = cuda_stream(args.opr->handle()); | |||||
execute_cutlass_conv_op(op, args.src_tensor->raw_ptr, filter_ptr, bias_ptr, | |||||
z_ptr, args.dst_tensor->raw_ptr, nullptr, n, hi, wi, | |||||
ci, co, fh, fw, ho, wo, ph, pw, sh, sw, dh, dw, | |||||
&alpha, &beta, &gamma, &delta, &theta, &threshold, | |||||
&dst_scale, stream, &src_zero); | |||||
do_exec(args, filter_ptr, bias_ptr, z_ptr, kern_param, nonlinear_mode, | |||||
alpha, beta, gamma, delta, theta, stream); | |||||
after_kernel_launch(); | |||||
} | } | ||||
std::string ConvBiasForwardImpl::AlgoInt4NCHW64IMMAImplicitGemmBase::to_string( | std::string ConvBiasForwardImpl::AlgoInt4NCHW64IMMAImplicitGemmBase::to_string( | ||||
@@ -10,10 +10,9 @@ | |||||
* implied. | * implied. | ||||
*/ | */ | ||||
#include "./algo.h" | |||||
#include "src/common/conv_bias.h" | #include "src/common/conv_bias.h" | ||||
#include "src/cuda/conv_bias/algo.h" | |||||
#include "src/cuda/conv_bias/cutlass_reorder_filter.cuh" | #include "src/cuda/conv_bias/cutlass_reorder_filter.cuh" | ||||
#include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh" | |||||
#include "src/cuda/conv_bias/reduce_filter.cuh" | #include "src/cuda/conv_bias/reduce_filter.cuh" | ||||
#include "src/cuda/convolution_helper/parameter.cuh" | #include "src/cuda/convolution_helper/parameter.cuh" | ||||
#include "src/cuda/utils.h" | #include "src/cuda/utils.h" | ||||
@@ -109,22 +108,43 @@ void ConvBiasForwardImpl::AlgoInt4NHWCIMMAImplicitGemmBase::exec( | |||||
if (args.z_layout->ndim > 0) | if (args.z_layout->ndim > 0) | ||||
z_ptr = args.z_tensor->raw_ptr; | z_ptr = args.z_tensor->raw_ptr; | ||||
// \note these constants of cutlass epilogue will be passed to method | |||||
// `execute_cutlass_conv_op` by pointer and interpreted as ElementCompute*, | |||||
// a different dtype here results in undefined epilogue behaviors | |||||
float alpha, beta, gamma, delta, theta; | float alpha, beta, gamma, delta, theta; | ||||
std::tie(alpha, beta, gamma, delta, theta) = get_constants(args); | std::tie(alpha, beta, gamma, delta, theta) = get_constants(args); | ||||
float dst_scale = 0.f; | |||||
float threshold = 0.f; | |||||
uint8_t src_zero = 0; | |||||
bool load_from_const = !(fh == 1 && fw == 1); | |||||
bool without_shared_load = ((co % m_algo_param.threadblock_n == 0) && | |||||
(m_algo_param.threadblock_n == 32 || | |||||
m_algo_param.threadblock_n == 64)); | |||||
if (args.dst_layout->dtype.enumv() == DTypeEnum::Quantized4Asymm) { | |||||
dst_scale = | |||||
args.dst_layout->dtype.param<dtype::Quantized4Asymm>().scale; | |||||
src_zero = args.src_layout->dtype.param<dtype::Quantized4Asymm>() | |||||
.zero_point; | |||||
} else { // DTypeEnum::QuantizedS4 | |||||
dst_scale = args.dst_layout->dtype.param<dtype::QuantizedS4>().scale; | |||||
} | |||||
ConvParam kern_param; | |||||
kern_param.n = n, kern_param.co = co, kern_param.ci = ci, | |||||
kern_param.hi = hi, kern_param.wi = wi, kern_param.ho = ho, | |||||
kern_param.wo = wo, kern_param.ph = ph, kern_param.pw = pw, | |||||
kern_param.sh = sh, kern_param.sw = sw, kern_param.fh = fh, | |||||
kern_param.fw = fw; | |||||
cudaStream_t stream = cuda_stream(args.opr->handle()); | |||||
uint32_t nonlinear_mode = static_cast<uint32_t>(param.nonlineMode); | |||||
const auto* op = get_cutlass_conv_op(args, ConvOperator::kFprop, | |||||
ConvType::kConvolution, | |||||
load_from_const, without_shared_load); | |||||
cudaStream_t stream = cuda_stream(args.opr->handle()); | |||||
execute_cutlass_conv_op(op, args.src_tensor->raw_ptr, filter_ptr, bias_ptr, | |||||
z_ptr, args.dst_tensor->raw_ptr, nullptr, n, hi, wi, | |||||
ci, co, fh, fw, ho, wo, ph, pw, sh, sw, dh, dw, | |||||
&alpha, &beta, &gamma, &delta, &theta, &threshold, | |||||
&dst_scale, stream, &src_zero); | |||||
do_exec(args, filter_ptr, bias_ptr, z_ptr, kern_param, nonlinear_mode, | |||||
alpha, beta, gamma, delta, theta, stream); | |||||
after_kernel_launch(); | |||||
} | } | ||||
std::string ConvBiasForwardImpl::AlgoInt4NHWCIMMAImplicitGemmBase::to_string( | std::string ConvBiasForwardImpl::AlgoInt4NHWCIMMAImplicitGemmBase::to_string( | ||||
@@ -10,12 +10,11 @@ | |||||
* implied. | * implied. | ||||
*/ | */ | ||||
#include "./algo.h" | |||||
#include "src/common/conv_bias.h" | |||||
#include "src/cuda/conv_bias/algo.h" | |||||
#include "src/cuda/conv_bias/cutlass_reorder_filter.cuh" | #include "src/cuda/conv_bias/cutlass_reorder_filter.cuh" | ||||
#include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh" | |||||
#include "src/cuda/convolution_helper/parameter.cuh" | #include "src/cuda/convolution_helper/parameter.cuh" | ||||
#include "src/cuda/utils.h" | #include "src/cuda/utils.h" | ||||
#include "src/common/conv_bias.h" | |||||
using namespace megdnn; | using namespace megdnn; | ||||
using namespace cuda; | using namespace cuda; | ||||
@@ -38,8 +37,7 @@ bool ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::is_available( | |||||
bool available = true; | bool available = true; | ||||
auto&& param = args.opr->param(); | auto&& param = args.opr->param(); | ||||
auto&& fm = args.filter_meta; | auto&& fm = args.filter_meta; | ||||
if (!check_bias_share_in_channel(*(args.bias_layout), | |||||
param.format)) | |||||
if (!check_bias_share_in_channel(*(args.bias_layout), param.format)) | |||||
return false; | return false; | ||||
if (param.format != Format::NCHW32 && param.format != Format::NCHW32_NCHW4) | if (param.format != Format::NCHW32 && param.format != Format::NCHW32_NCHW4) | ||||
return false; | return false; | ||||
@@ -137,19 +135,16 @@ void ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::exec( | |||||
args.preprocessed_filter->tensors[0].raw_ptr); | args.preprocessed_filter->tensors[0].raw_ptr); | ||||
} | } | ||||
ConvParam kern_param; | |||||
kern_param.n = n, kern_param.co = co, kern_param.ci = ci, | |||||
kern_param.hi = hi, kern_param.wi = wi, kern_param.ho = ho, | |||||
kern_param.wo = wo, kern_param.ph = ph, kern_param.pw = pw, | |||||
kern_param.sh = sh, kern_param.sw = sw, kern_param.fh = fh, | |||||
kern_param.fw = fw; | |||||
float src_scale = args.src_layout->dtype.param<dtype::QuantizedS8>().scale, | float src_scale = args.src_layout->dtype.param<dtype::QuantizedS8>().scale, | ||||
filter_scale = | filter_scale = | ||||
args.filter_layout->dtype.param<dtype::QuantizedS8>().scale, | args.filter_layout->dtype.param<dtype::QuantizedS8>().scale, | ||||
bias_scale = | bias_scale = | ||||
args.bias_layout->dtype.param<dtype::QuantizedS32>().scale, | args.bias_layout->dtype.param<dtype::QuantizedS32>().scale, | ||||
dst_scale = args.dst_layout->dtype.param<dtype::QuantizedS8>().scale; | dst_scale = args.dst_layout->dtype.param<dtype::QuantizedS8>().scale; | ||||
// \note these constants of cutlass epilogue will be passed to method | |||||
// `execute_cutlass_conv_op` by pointer and interpreted as ElementCompute*, | |||||
// a different dtype here results in undefined epilogue behaviors | |||||
float alpha = src_scale * filter_scale / dst_scale, | float alpha = src_scale * filter_scale / dst_scale, | ||||
beta = bias_scale / dst_scale; | beta = bias_scale / dst_scale; | ||||
int8_t* z_dev_ptr = nullptr; | int8_t* z_dev_ptr = nullptr; | ||||
@@ -159,80 +154,20 @@ void ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::exec( | |||||
float z_scale = args.z_layout->dtype.param<dtype::QuantizedS8>().scale; | float z_scale = args.z_layout->dtype.param<dtype::QuantizedS8>().scale; | ||||
gamma = z_scale / dst_scale; | gamma = z_scale / dst_scale; | ||||
} | } | ||||
uint32_t nonlinear_mode = static_cast<uint32_t>(param.nonlineMode); | |||||
if (fh == 1 && fw == 1) { | |||||
if (param.format == Format::NCHW32) { | |||||
cutlass_wrapper::do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32< | |||||
false>( | |||||
args.src_tensor->compatible_ptr<int8_t>(), filter_ptr, | |||||
args.bias_tensor->compatible_ptr<int32_t>(), z_dev_ptr, | |||||
args.dst_tensor->compatible_ptr<int8_t>(), nullptr, | |||||
kern_param, nonlinear_mode, alpha, beta, gamma, dst_scale, | |||||
cutlass_wrapper::GemmCoord{m_algo_param.threadblock_m, | |||||
m_algo_param.threadblock_n, | |||||
m_algo_param.threadblock_k}, | |||||
cutlass_wrapper::GemmCoord{m_algo_param.warp_m, | |||||
m_algo_param.warp_n, | |||||
m_algo_param.warp_k}, | |||||
m_algo_param.stage, stream); | |||||
} else { | |||||
megdnn_assert(param.format == Format::NCHW32_NCHW4); | |||||
cutlass_wrapper:: | |||||
do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4< | |||||
false>( | |||||
args.src_tensor->compatible_ptr<int8_t>(), | |||||
filter_ptr, | |||||
args.bias_tensor->compatible_ptr<int32_t>(), | |||||
z_dev_ptr, | |||||
args.dst_tensor->compatible_ptr<int8_t>(), nullptr, | |||||
kern_param, nonlinear_mode, alpha, beta, gamma, | |||||
dst_scale, | |||||
cutlass_wrapper::GemmCoord{ | |||||
m_algo_param.threadblock_m, | |||||
m_algo_param.threadblock_n, | |||||
m_algo_param.threadblock_k}, | |||||
cutlass_wrapper::GemmCoord{m_algo_param.warp_m, | |||||
m_algo_param.warp_n, | |||||
m_algo_param.warp_k}, | |||||
m_algo_param.stage, stream); | |||||
} | |||||
} else { | |||||
if (param.format == Format::NCHW32) { | |||||
cutlass_wrapper::do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32< | |||||
true>( | |||||
args.src_tensor->compatible_ptr<int8_t>(), filter_ptr, | |||||
args.bias_tensor->compatible_ptr<int32_t>(), z_dev_ptr, | |||||
args.dst_tensor->compatible_ptr<int8_t>(), nullptr, | |||||
kern_param, nonlinear_mode, alpha, beta, gamma, dst_scale, | |||||
cutlass_wrapper::GemmCoord{m_algo_param.threadblock_m, | |||||
m_algo_param.threadblock_n, | |||||
m_algo_param.threadblock_k}, | |||||
cutlass_wrapper::GemmCoord{m_algo_param.warp_m, | |||||
m_algo_param.warp_n, | |||||
m_algo_param.warp_k}, | |||||
m_algo_param.stage, stream); | |||||
} else { | |||||
megdnn_assert(param.format == Format::NCHW32_NCHW4); | |||||
cutlass_wrapper:: | |||||
do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4< | |||||
true>( | |||||
args.src_tensor->compatible_ptr<int8_t>(), | |||||
filter_ptr, | |||||
args.bias_tensor->compatible_ptr<int32_t>(), | |||||
z_dev_ptr, | |||||
args.dst_tensor->compatible_ptr<int8_t>(), nullptr, | |||||
kern_param, nonlinear_mode, alpha, beta, gamma, | |||||
dst_scale, | |||||
cutlass_wrapper::GemmCoord{ | |||||
m_algo_param.threadblock_m, | |||||
m_algo_param.threadblock_n, | |||||
m_algo_param.threadblock_k}, | |||||
cutlass_wrapper::GemmCoord{m_algo_param.warp_m, | |||||
m_algo_param.warp_n, | |||||
m_algo_param.warp_k}, | |||||
m_algo_param.stage, stream); | |||||
} | |||||
} | |||||
float delta = 0.f, theta = 0.f, threshold = 0.f; | |||||
bool load_from_const = !(fh == 1 && fw == 1); | |||||
bool without_shared_load = (param.format == Format::NCHW32); | |||||
const auto* op = get_cutlass_conv_op(args, ConvOperator::kFprop, | |||||
ConvType::kConvolution, | |||||
load_from_const, without_shared_load); | |||||
execute_cutlass_conv_op( | |||||
op, args.src_tensor->raw_ptr, filter_ptr, args.bias_tensor->raw_ptr, | |||||
z_dev_ptr, args.dst_tensor->raw_ptr, nullptr, n, hi, wi, ci, co, fh, | |||||
fw, ho, wo, ph, pw, sh, sw, dh, dw, &alpha, &beta, &gamma, &delta, | |||||
&theta, &threshold, &dst_scale, stream); | |||||
after_kernel_launch(); | after_kernel_launch(); | ||||
} | } | ||||
@@ -249,9 +184,8 @@ size_t ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm:: | |||||
return 0_z; | return 0_z; | ||||
} | } | ||||
SmallVector<TensorLayout> ConvBiasForwardImpl:: | |||||
AlgoInt8NCHW32IMMAImplicitGemm::deduce_preprocessed_filter_layout( | |||||
const SizeArgs& args) const { | |||||
SmallVector<TensorLayout> ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm:: | |||||
deduce_preprocessed_filter_layout(const SizeArgs& args) const { | |||||
return {args.filter_layout->collapse_contiguous()}; | return {args.filter_layout->collapse_contiguous()}; | ||||
} | } | ||||
@@ -6,14 +6,14 @@ | |||||
* | * | ||||
* Unless required by applicable law or agreed to in writing, | * Unless required by applicable law or agreed to in writing, | ||||
* software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | */ | ||||
#include "./algo.h" | |||||
#include "src/cuda/utils.h" | |||||
#include "src/cuda/convolution_helper/parameter.cuh" | |||||
#include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh" | |||||
#include "src/common/conv_bias.h" | #include "src/common/conv_bias.h" | ||||
#include "src/cuda/conv_bias/algo.h" | |||||
#include "src/cuda/convolution_helper/parameter.cuh" | |||||
#include "src/cuda/utils.h" | |||||
using namespace megdnn; | using namespace megdnn; | ||||
using namespace cuda; | using namespace cuda; | ||||
@@ -34,8 +34,7 @@ bool ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::is_available( | |||||
bool available = true; | bool available = true; | ||||
auto&& param = args.opr->param(); | auto&& param = args.opr->param(); | ||||
auto&& fm = args.filter_meta; | auto&& fm = args.filter_meta; | ||||
if (!check_bias_share_in_channel(*(args.bias_layout), | |||||
param.format)) | |||||
if (!check_bias_share_in_channel(*(args.bias_layout), param.format)) | |||||
return false; | return false; | ||||
bool valid_format = param.format == Format::NCHW4_NCHW32 && | bool valid_format = param.format == Format::NCHW4_NCHW32 && | ||||
m_algo_param.threadblock_m % 32 == 0; | m_algo_param.threadblock_m % 32 == 0; | ||||
@@ -48,7 +47,8 @@ bool ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::is_available( | |||||
(args.dst_layout->dtype.enumv() == DTypeEnum::QuantizedS4 || | (args.dst_layout->dtype.enumv() == DTypeEnum::QuantizedS4 || | ||||
args.dst_layout->dtype.enumv() == DTypeEnum::Quantized4Asymm); | args.dst_layout->dtype.enumv() == DTypeEnum::Quantized4Asymm); | ||||
valid_format |= param.format == Format::NCHW4; | valid_format |= param.format == Format::NCHW4; | ||||
if (!valid_format) return false; | |||||
if (!valid_format) | |||||
return false; | |||||
size_t n = args.src_layout->operator[](0), | size_t n = args.src_layout->operator[](0), | ||||
ci = args.src_layout->operator[](1) * 4, | ci = args.src_layout->operator[](1) * 4, | ||||
hi = args.src_layout->operator[](2), | hi = args.src_layout->operator[](2), | ||||
@@ -170,16 +170,13 @@ void ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::exec( | |||||
args.preprocessed_filter->tensors[0].raw_ptr); | args.preprocessed_filter->tensors[0].raw_ptr); | ||||
} | } | ||||
convolution::ConvParam kern_param; | |||||
kern_param.n = n, kern_param.co = co, kern_param.ci = ci, | |||||
kern_param.hi = hi, kern_param.wi = wi, kern_param.ho = ho, | |||||
kern_param.wo = wo, kern_param.ph = ph, kern_param.pw = pw, | |||||
kern_param.sh = sh, kern_param.sw = sw, kern_param.fh = fh, | |||||
kern_param.fw = fw; | |||||
float src_scale = args.src_layout->dtype.param<dtype::QuantizedS8>().scale, | float src_scale = args.src_layout->dtype.param<dtype::QuantizedS8>().scale, | ||||
filter_scale = | filter_scale = | ||||
args.filter_layout->dtype.param<dtype::QuantizedS8>().scale; | args.filter_layout->dtype.param<dtype::QuantizedS8>().scale; | ||||
// \note these constants of cutlass epilogue will be passed to method | |||||
// `execute_cutlass_conv_op` by pointer and interpreted as ElementCompute*, | |||||
// a different dtype here results in undefined epilogue behaviors | |||||
float alpha = src_scale * filter_scale; | float alpha = src_scale * filter_scale; | ||||
float beta = 1.f; | float beta = 1.f; | ||||
float dst_scale = 1.f; | float dst_scale = 1.f; | ||||
@@ -192,13 +189,15 @@ void ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::exec( | |||||
if (args.bias_layout->dtype.enumv() == DTypeEnum::QuantizedS32) { | if (args.bias_layout->dtype.enumv() == DTypeEnum::QuantizedS32) { | ||||
megdnn_assert(args.dst_layout->dtype.category() == | megdnn_assert(args.dst_layout->dtype.category() == | ||||
DTypeCategory::QUANTIZED); | DTypeCategory::QUANTIZED); | ||||
float bias_scale = args.bias_layout->dtype.param<dtype::QuantizedS32>() | |||||
.scale; | |||||
float bias_scale = | |||||
args.bias_layout->dtype.param<dtype::QuantizedS32>().scale; | |||||
dst_scale = get_scale(args.dst_layout->dtype); | dst_scale = get_scale(args.dst_layout->dtype); | ||||
alpha /= dst_scale, beta = bias_scale / dst_scale; | alpha /= dst_scale, beta = bias_scale / dst_scale; | ||||
} | } | ||||
float delta = 0.f; | float delta = 0.f; | ||||
void* z_ptr = nullptr; | |||||
if (args.z_layout->ndim > 0) { | if (args.z_layout->ndim > 0) { | ||||
z_ptr = args.z_tensor->raw_ptr; | |||||
gamma = 1.f; | gamma = 1.f; | ||||
if (args.z_layout->dtype.category() == DTypeCategory::QUANTIZED) { | if (args.z_layout->dtype.category() == DTypeCategory::QUANTIZED) { | ||||
megdnn_assert(args.dst_layout->dtype.category() == | megdnn_assert(args.dst_layout->dtype.category() == | ||||
@@ -213,98 +212,20 @@ void ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::exec( | |||||
delta = -z_zero * gamma; | delta = -z_zero * gamma; | ||||
} | } | ||||
} | } | ||||
uint32_t nonlinear_mode = static_cast<uint32_t>(param.nonlineMode); | |||||
bool nonunity_kernel = !(fh == 1 && fw == 1); | |||||
#define DISPATCH(_nonunity_kernel) \ | |||||
if (nonunity_kernel == _nonunity_kernel) { \ | |||||
cb(_nonunity_kernel) \ | |||||
} | |||||
if (param.format == Format::NCHW4) { | |||||
#define cb(_nonunity_kernel) \ | |||||
cutlass_wrapper::do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4< \ | |||||
_nonunity_kernel>( \ | |||||
args.src_tensor->compatible_ptr<int8_t>(), filter_ptr, \ | |||||
args.bias_tensor->compatible_ptr<int32_t>(), \ | |||||
args.z_tensor->compatible_ptr<int8_t>(), \ | |||||
args.dst_tensor->compatible_ptr<int8_t>(), nullptr, kern_param, \ | |||||
nonlinear_mode, alpha, beta, gamma, dst_scale, \ | |||||
cutlass_wrapper::GemmCoord{m_algo_param.threadblock_m, \ | |||||
m_algo_param.threadblock_n, \ | |||||
m_algo_param.threadblock_k}, \ | |||||
cutlass_wrapper::GemmCoord{m_algo_param.warp_m, \ | |||||
m_algo_param.warp_n, \ | |||||
m_algo_param.warp_k}, \ | |||||
m_algo_param.stage, stream); | |||||
DISPATCH(true); | |||||
DISPATCH(false); | |||||
#undef cb | |||||
} else if (param.format == Format::NCHW4_NCHW) { | |||||
#define cb(_nonunity_kernel) \ | |||||
cutlass_wrapper::do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nchw< \ | |||||
_nonunity_kernel>( \ | |||||
args.src_tensor->compatible_ptr<int8_t>(), filter_ptr, \ | |||||
args.bias_tensor->compatible_ptr<float>(), \ | |||||
args.z_tensor->compatible_ptr<float>(), \ | |||||
args.dst_tensor->compatible_ptr<float>(), nullptr, kern_param, \ | |||||
nonlinear_mode, alpha, beta, gamma, dst_scale, \ | |||||
cutlass_wrapper::GemmCoord{m_algo_param.threadblock_m, \ | |||||
m_algo_param.threadblock_n, \ | |||||
m_algo_param.threadblock_k}, \ | |||||
cutlass_wrapper::GemmCoord{m_algo_param.warp_m, \ | |||||
m_algo_param.warp_n, \ | |||||
m_algo_param.warp_k}, \ | |||||
m_algo_param.stage, stream); | |||||
DISPATCH(true); | |||||
DISPATCH(false); | |||||
#undef cb | |||||
} else if (param.format == Format::NCHW4_NHWC) { | |||||
#define cb(_signedness) \ | |||||
cutlass_wrapper::do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nhwc< \ | |||||
_signedness>( \ | |||||
args.src_tensor->compatible_ptr<int8_t>(), filter_ptr, \ | |||||
args.bias_tensor->compatible_ptr<int32_t>(), \ | |||||
reinterpret_cast<int8_t*>(args.z_tensor->raw_ptr), \ | |||||
reinterpret_cast<int8_t*>(args.dst_tensor->raw_ptr), nullptr, \ | |||||
kern_param, nonlinear_mode, alpha, beta, gamma, delta, theta, \ | |||||
dst_scale, \ | |||||
cutlass_wrapper::GemmCoord{m_algo_param.threadblock_m, \ | |||||
m_algo_param.threadblock_n, \ | |||||
m_algo_param.threadblock_k}, \ | |||||
cutlass_wrapper::GemmCoord{m_algo_param.warp_m, \ | |||||
m_algo_param.warp_n, \ | |||||
m_algo_param.warp_k}, \ | |||||
m_algo_param.stage, stream); | |||||
if (args.dst_layout->dtype.enumv() == DTypeEnum::QuantizedS4) { | |||||
cb(true); | |||||
} else { | |||||
megdnn_assert(args.dst_layout->dtype.enumv() == | |||||
DTypeEnum::Quantized4Asymm); | |||||
cb(false); | |||||
} | |||||
#undef cb | |||||
} else { | |||||
megdnn_assert(param.format == Format::NCHW4_NCHW32); | |||||
#define cb(_nonunity_kernel) \ | |||||
cutlass_wrapper:: \ | |||||
do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32< \ | |||||
_nonunity_kernel>( \ | |||||
args.src_tensor->compatible_ptr<int8_t>(), filter_ptr, \ | |||||
args.bias_tensor->compatible_ptr<int32_t>(), \ | |||||
args.z_tensor->compatible_ptr<int8_t>(), \ | |||||
args.dst_tensor->compatible_ptr<int8_t>(), nullptr, \ | |||||
kern_param, nonlinear_mode, alpha, beta, gamma, dst_scale, \ | |||||
cutlass_wrapper::GemmCoord{m_algo_param.threadblock_m, \ | |||||
m_algo_param.threadblock_n, \ | |||||
m_algo_param.threadblock_k}, \ | |||||
cutlass_wrapper::GemmCoord{m_algo_param.warp_m, \ | |||||
m_algo_param.warp_n, \ | |||||
m_algo_param.warp_k}, \ | |||||
m_algo_param.stage, stream); | |||||
DISPATCH(true); | |||||
DISPATCH(false); | |||||
#undef cb | |||||
#undef DISPATCH | |||||
} | |||||
float threshold = 0.f; | |||||
bool load_from_const = !(fh == 1 && fw == 1); | |||||
bool without_shared_load = false; | |||||
const auto* op = get_cutlass_conv_op(args, ConvOperator::kFprop, | |||||
ConvType::kConvolution, | |||||
load_from_const, without_shared_load); | |||||
execute_cutlass_conv_op( | |||||
op, args.src_tensor->raw_ptr, filter_ptr, args.bias_tensor->raw_ptr, | |||||
z_ptr, args.dst_tensor->raw_ptr, nullptr, n, hi, wi, ci, co, fh, fw, | |||||
ho, wo, ph, pw, sh, sw, dh, dw, &alpha, &beta, &gamma, &delta, | |||||
&theta, &threshold, &dst_scale, stream); | |||||
after_kernel_launch(); | after_kernel_launch(); | ||||
} | } | ||||
@@ -10,8 +10,7 @@ | |||||
* implied. | * implied. | ||||
*/ | */ | ||||
#include "./algo.h" | |||||
#include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh" | |||||
#include "src/cuda/conv_bias/algo.h" | |||||
#include "src/cuda/conv_bias/reduce_filter.cuh" | #include "src/cuda/conv_bias/reduce_filter.cuh" | ||||
#include "src/cuda/utils.h" | #include "src/cuda/utils.h" | ||||
@@ -120,32 +119,15 @@ ConvBiasForwardImpl::AlgoUInt4Int4NCHW64IMMAImplicitGemm::get_constants( | |||||
delta = -z_zero * gamma; | delta = -z_zero * gamma; | ||||
} | } | ||||
return {alpha, beta, gamma, delta, theta}; | |||||
} | |||||
// identity epilogue has no theta: | |||||
// alpha * accumulator + beta * bias + gamma * source + delta | |||||
if (args.opr->param().nonlineMode == | |||||
param::ConvBias::NonlineMode::IDENTITY) { | |||||
delta += theta; | |||||
theta = 0.f; | |||||
} | |||||
void ConvBiasForwardImpl::AlgoUInt4Int4NCHW64IMMAImplicitGemm::do_exec( | |||||
const ExecArgs& args, void* filter_ptr, void* bias_ptr, void* z_ptr, | |||||
ConvParam kern_param, uint32_t nonlinear_mode, float alpha, float beta, | |||||
float gamma, float delta, float theta, cudaStream_t stream) const { | |||||
float dst_scale = | |||||
args.dst_layout->dtype.param<dtype::Quantized4Asymm>().scale; | |||||
uint8_t src_zero = | |||||
args.src_layout->dtype.param<dtype::Quantized4Asymm>().zero_point; | |||||
cutlass_wrapper::GemmCoord threadblock_shape{m_algo_param.threadblock_m, | |||||
m_algo_param.threadblock_n, | |||||
m_algo_param.threadblock_k}; | |||||
cutlass_wrapper::GemmCoord warp_shape{ | |||||
m_algo_param.warp_m, m_algo_param.warp_n, m_algo_param.warp_k}; | |||||
cutlass_wrapper::do_conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64< | |||||
true>(reinterpret_cast<uint8_t*>(args.src_tensor->raw_ptr), | |||||
reinterpret_cast<int8_t*>(filter_ptr), | |||||
reinterpret_cast<int32_t*>(bias_ptr), | |||||
reinterpret_cast<uint8_t*>(z_ptr), | |||||
reinterpret_cast<uint8_t*>(args.dst_tensor->raw_ptr), nullptr, | |||||
kern_param, nonlinear_mode, alpha, beta, gamma, delta, theta, | |||||
dst_scale, src_zero, threadblock_shape, warp_shape, | |||||
m_algo_param.stage, stream); | |||||
return {alpha, beta, gamma, delta, theta}; | |||||
} | } | ||||
void ConvBiasForwardImpl::AlgoUInt4Int4NCHW64IMMAImplicitGemm::update_bias( | void ConvBiasForwardImpl::AlgoUInt4Int4NCHW64IMMAImplicitGemm::update_bias( | ||||
@@ -10,8 +10,7 @@ | |||||
* implied. | * implied. | ||||
*/ | */ | ||||
#include "./algo.h" | |||||
#include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh" | |||||
#include "src/cuda/conv_bias/algo.h" | |||||
#include "src/cuda/conv_bias/reduce_filter.cuh" | #include "src/cuda/conv_bias/reduce_filter.cuh" | ||||
#include "src/cuda/utils.h" | #include "src/cuda/utils.h" | ||||
@@ -121,44 +120,15 @@ ConvBiasForwardImpl::AlgoUInt4Int4NHWCIMMAImplicitGemm::get_constants( | |||||
delta = -z_zero * gamma; | delta = -z_zero * gamma; | ||||
} | } | ||||
return {alpha, beta, gamma, delta, theta}; | |||||
} | |||||
void ConvBiasForwardImpl::AlgoUInt4Int4NHWCIMMAImplicitGemm::do_exec( | |||||
const ExecArgs& args, void* filter_ptr, void* bias_ptr, void* z_ptr, | |||||
ConvParam kern_param, uint32_t nonlinear_mode, float alpha, float beta, | |||||
float gamma, float delta, float theta, cudaStream_t stream) const { | |||||
float dst_scale = | |||||
args.dst_layout->dtype.param<dtype::Quantized4Asymm>().scale; | |||||
uint8_t src_zero = | |||||
args.src_layout->dtype.param<dtype::Quantized4Asymm>().zero_point; | |||||
cutlass_wrapper::GemmCoord threadblock_shape{m_algo_param.threadblock_m, | |||||
m_algo_param.threadblock_n, | |||||
m_algo_param.threadblock_k}; | |||||
cutlass_wrapper::GemmCoord warp_shape{ | |||||
m_algo_param.warp_m, m_algo_param.warp_n, m_algo_param.warp_k}; | |||||
if (kern_param.fh == 1 && kern_param.fw == 1) { | |||||
cutlass_wrapper::do_conv_bias_uint4_int4_implicit_gemm_imma_nhwc<false>( | |||||
reinterpret_cast<uint8_t*>(args.src_tensor->raw_ptr), | |||||
reinterpret_cast<int8_t*>(filter_ptr), | |||||
reinterpret_cast<int32_t*>(bias_ptr), | |||||
reinterpret_cast<uint8_t*>(z_ptr), | |||||
reinterpret_cast<uint8_t*>(args.dst_tensor->raw_ptr), nullptr, | |||||
kern_param, nonlinear_mode, alpha, beta, gamma, delta, theta, | |||||
dst_scale, src_zero, threadblock_shape, warp_shape, | |||||
m_algo_param.access_size, m_algo_param.stage, stream); | |||||
} else { | |||||
cutlass_wrapper::do_conv_bias_uint4_int4_implicit_gemm_imma_nhwc<true>( | |||||
reinterpret_cast<uint8_t*>(args.src_tensor->raw_ptr), | |||||
reinterpret_cast<int8_t*>(filter_ptr), | |||||
reinterpret_cast<int32_t*>(bias_ptr), | |||||
reinterpret_cast<uint8_t*>(z_ptr), | |||||
reinterpret_cast<uint8_t*>(args.dst_tensor->raw_ptr), nullptr, | |||||
kern_param, nonlinear_mode, alpha, beta, gamma, delta, theta, | |||||
dst_scale, src_zero, threadblock_shape, warp_shape, | |||||
m_algo_param.access_size, m_algo_param.stage, stream); | |||||
// identity epilogue has no theta: | |||||
// alpha * accumulator + beta * bias + gamma * source + delta | |||||
if (args.opr->param().nonlineMode == | |||||
param::ConvBias::NonlineMode::IDENTITY) { | |||||
delta += theta; | |||||
theta = 0.f; | |||||
} | } | ||||
return {alpha, beta, gamma, delta, theta}; | |||||
} | } | ||||
void ConvBiasForwardImpl::AlgoUInt4Int4NHWCIMMAImplicitGemm::update_bias( | void ConvBiasForwardImpl::AlgoUInt4Int4NHWCIMMAImplicitGemm::update_bias( | ||||
@@ -57,6 +57,7 @@ public: | |||||
class AlgoBatchedMatmul; | class AlgoBatchedMatmul; | ||||
class AlgoGroupConvGeneral; | class AlgoGroupConvGeneral; | ||||
class AlgoQUInt4x4x32WMMA; | class AlgoQUInt4x4x32WMMA; | ||||
class AlgoCutlassConvolutionBase; | |||||
class AlgoInt8CHWN4DotProdImplicitGemm; | class AlgoInt8CHWN4DotProdImplicitGemm; | ||||
class AlgoInt8NCHW4DotProdImplicitGemm; | class AlgoInt8NCHW4DotProdImplicitGemm; | ||||
class AlgoInt8CHWN4IMMAImplicitGemm; | class AlgoInt8CHWN4IMMAImplicitGemm; | ||||
@@ -1,100 +0,0 @@ | |||||
/** | |||||
* \file src/cuda/convolution/backward_data/cutlass_deconvolution_wrapper.cu | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
// ignore warning of cutlass | |||||
#pragma GCC diagnostic push | |||||
#pragma GCC diagnostic ignored "-Wunused-parameter" | |||||
#pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||||
#if !MEGDNN_TEGRA_X1 | |||||
#include "cutlass/convolution/device/convolution.h" | |||||
#endif | |||||
#include "src/common/opr_param_defs_enumv.cuh" | |||||
#include "src/cuda/convolution/backward_data/cutlass_deconvolution_wrapper.cuh" | |||||
#pragma GCC diagnostic pop | |||||
using namespace megdnn; | |||||
using namespace cuda; | |||||
using namespace cutlass_wrapper; | |||||
/* ================ cutlass kernel wrapper for nchw4 layout ================= */ | |||||
#if MEGDNN_TEGRA_X1 | |||||
void megdnn::cuda::cutlass_wrapper::do_deconv_int8_implicit_gemm_dp4a_ncdiv4hw4( | |||||
const int8_t* /* d_src */, const int8_t* /* d_filter */, | |||||
int8_t* /* d_dst */, int* /* workspace */, | |||||
const convolution::ConvParam& /* param */, float /* alpha */, | |||||
const GemmCoord& /* threadblock_shape */, | |||||
const GemmCoord& /* warp_shape */, int /* stages */, | |||||
cudaStream_t /* stream */) {} | |||||
#else | |||||
void megdnn::cuda::cutlass_wrapper::do_deconv_int8_implicit_gemm_dp4a_ncdiv4hw4( | |||||
const int8_t* d_src, const int8_t* d_filter, int8_t* d_dst, | |||||
int* workspace, const convolution::ConvParam& param, float alpha, | |||||
const GemmCoord& threadblock_shape, const GemmCoord& warp_shape, | |||||
int stages, cudaStream_t stream) { | |||||
#define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \ | |||||
threadblock_k_, warp_m_, warp_n_, \ | |||||
warp_k_, stage_, aligned_) \ | |||||
if (threadblock_shape.m() == threadblock_m_ && \ | |||||
threadblock_shape.n() == threadblock_n_ && \ | |||||
threadblock_shape.k() == threadblock_k_ && \ | |||||
warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ | |||||
warp_shape.k() == warp_k_ && stages == stage_) { \ | |||||
using ThreadBlockShape = \ | |||||
cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \ | |||||
threadblock_k_>; \ | |||||
using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \ | |||||
using InstructionShape = cutlass::gemm::GemmShape<1, 1, 4>; \ | |||||
using Deconvolution = cutlass::conv::device::Deconvolution< \ | |||||
int8_t, cutlass::layout::TensorNCxHWx<4>, int8_t, \ | |||||
cutlass::layout::TensorKxRSCx<4>, ElementOutput, \ | |||||
cutlass::layout::TensorNCxHWx<4>, int32_t, \ | |||||
cutlass::layout::TensorNCxHWx<4>, int32_t, \ | |||||
cutlass::arch::OpClassSimt, cutlass::arch::Sm61, \ | |||||
ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ | |||||
cutlass::conv::threadblock:: \ | |||||
ConvolutionDgradNCxHWxThreadblockSwizzle, \ | |||||
stage_, 4, aligned_, true, cutlass::arch::OpMultiplyAdd>; \ | |||||
typename Deconvolution::ConvolutionParameter conv_param( \ | |||||
param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ | |||||
param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ | |||||
param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \ | |||||
return cutlass_deconvolution_wrapper<Deconvolution>( \ | |||||
d_src, d_filter, nullptr, nullptr, d_dst, workspace, \ | |||||
conv_param, epilogue, stream); \ | |||||
} | |||||
#define DISPATCH_KERNEL \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 64, 8, 16, 64, 8, 2, 4); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 128, 16, 16, 64, 16, 2, 4); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 128, 16, 16, 128, 16, 1, 8); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 2, 16); \ | |||||
DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 64, 32, 32, 2, 16); \ | |||||
megdnn_assert(false, \ | |||||
"unsupported threadblock shape (%dx%dx%d) and warp shape " \ | |||||
"(%dx%dx%d)", \ | |||||
threadblock_shape.m(), threadblock_shape.n(), \ | |||||
threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ | |||||
warp_shape.k()); | |||||
using ElementOutput = int8_t; | |||||
using ElementAccumulator = int32_t; | |||||
using ElementBias = int32_t; | |||||
using ElementCompute = float; | |||||
using EpilogueOp = cutlass::epilogue::thread::BiasAddLinearCombinationClamp< | |||||
ElementOutput, 4, ElementAccumulator, ElementBias, ElementCompute>; | |||||
typename EpilogueOp::Params epilogue{alpha, 0, 0}; | |||||
DISPATCH_KERNEL; | |||||
#undef DISPATCH_KERNEL_WITH_TILE_SHAPE | |||||
#undef DISPATCH_KERNEL | |||||
} | |||||
#endif | |||||
// vim: syntax=cuda.doxygen |
@@ -1,44 +0,0 @@ | |||||
/** | |||||
* \file src/cuda/convolution/backward_data/cutlass_deconvolution_wrapper.cuh | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#pragma once | |||||
#include "cutlass/gemm/gemm.h" | |||||
#include "src/cuda/convolution_helper/parameter.cuh" | |||||
#include "src/cuda/utils.cuh" | |||||
namespace megdnn { | |||||
namespace cuda { | |||||
namespace cutlass_wrapper { | |||||
using GemmCoord = cutlass::gemm::GemmCoord; | |||||
template <typename Convolution> | |||||
void cutlass_deconvolution_wrapper( | |||||
const typename Convolution::ElementSrc* d_src, | |||||
const typename Convolution::ElementFilter* d_filter, | |||||
const typename Convolution::ElementBias* d_bias, | |||||
const typename Convolution::ElementDst* d_z, | |||||
typename Convolution::ElementDst* d_dst, int* workspace, | |||||
typename Convolution::ConvolutionParameter const& conv_param, | |||||
typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||||
cudaStream_t stream); | |||||
void do_deconv_int8_implicit_gemm_dp4a_ncdiv4hw4( | |||||
const int8_t* d_src, const int8_t* d_filter, int8_t* d_dst, | |||||
int* workspace, const convolution::ConvParam& param, float alpha, | |||||
const GemmCoord& threadblock_shape, const GemmCoord& warp_shape, | |||||
int stages, cudaStream_t stream); | |||||
} // namespace cutlass_wrapper | |||||
} // namespace cuda | |||||
} // namespace megdnn | |||||
// vim: syntax=cuda.doxygen |
@@ -1,62 +0,0 @@ | |||||
/** | |||||
* \file | |||||
* dnn/src/cuda/convolution/backward_data/implicit_gemm_deconv_cutlass_wrapper.cuinl | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#include "cutlass/convolution/device/convolution.h" | |||||
#include "src/cuda/convolution/backward_data/cutlass_deconvolution_wrapper.cuh" | |||||
using namespace megdnn; | |||||
using namespace cuda; | |||||
using namespace cutlass_wrapper; | |||||
template <typename Deconvolution> | |||||
void megdnn::cuda::cutlass_wrapper::cutlass_deconvolution_wrapper( | |||||
const typename Deconvolution::ElementSrc* d_src, | |||||
const typename Deconvolution::ElementFilter* d_filter, | |||||
const typename Deconvolution::ElementBias* d_bias, | |||||
const typename Deconvolution::ElementDst* d_z, | |||||
typename Deconvolution::ElementDst* d_dst, int* workspace, | |||||
typename Deconvolution::ConvolutionParameter const& conv_param, | |||||
typename Deconvolution::EpilogueOutputOp::Params const& epilogue, | |||||
cudaStream_t stream) { | |||||
typename Deconvolution::TensorRefSrc tensor_src{ | |||||
const_cast<typename Deconvolution::ElementSrc*>(d_src), | |||||
Deconvolution::LayoutSrc::packed( | |||||
{conv_param.N, conv_param.P, conv_param.Q, conv_param.K})}; | |||||
typename Deconvolution::TensorRefFilter tensor_filter{ | |||||
const_cast<typename Deconvolution::ElementFilter*>(d_filter), | |||||
Deconvolution::LayoutFilter::packed( | |||||
{conv_param.K, conv_param.R, conv_param.S, conv_param.C})}; | |||||
typename Deconvolution::TensorRefBias tensor_bias{ | |||||
const_cast<typename Deconvolution::ElementBias*>(d_bias), | |||||
Deconvolution::LayoutBias::packed({1, 1, 1, conv_param.K})}; | |||||
typename Deconvolution::TensorRefDst tensor_z{ | |||||
const_cast<typename Deconvolution::ElementDst*>(d_z), | |||||
Deconvolution::LayoutDst::packed( | |||||
{conv_param.N, conv_param.H, conv_param.W, conv_param.C})}; | |||||
typename Deconvolution::TensorRefDst tensor_dst{ | |||||
d_dst, | |||||
Deconvolution::LayoutDst::packed( | |||||
{conv_param.N, conv_param.H, conv_param.W, conv_param.C})}; | |||||
typename Deconvolution::Arguments arguments{conv_param, | |||||
tensor_src.non_const_ref(), | |||||
tensor_filter.non_const_ref(), | |||||
tensor_bias.non_const_ref(), | |||||
tensor_z.non_const_ref(), | |||||
tensor_dst.non_const_ref(), | |||||
epilogue}; | |||||
Deconvolution deconv_op; | |||||
cutlass_check(deconv_op.initialize(arguments, workspace)); | |||||
cutlass_check(deconv_op(stream)); | |||||
after_kernel_launch(); | |||||
} | |||||
// vim: syntax=cuda.doxygen |
@@ -1,5 +1,6 @@ | |||||
/** | /** | ||||
* \file dnn/src/cuda/convolution/backward_data/implicit_gemm_int8_nchw4_dp4a.cpp | |||||
* \file | |||||
* dnn/src/cuda/convolution/backward_data/implicit_gemm_int8_nchw4_dp4a.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
@@ -10,11 +11,11 @@ | |||||
* implied. | * implied. | ||||
*/ | */ | ||||
#include "./algo.h" | |||||
#include "src/cuda/utils.h" | |||||
#include "src/cuda/convolution_helper/parameter.cuh" | |||||
#include "src/cuda/convolution/backward_data/cutlass_deconvolution_wrapper.cuh" | |||||
#include "src/cuda/convolution/backward_data/algo.h" | |||||
#include "src/cuda/convolution/backward_data/deconv_int8_helper.cuh" | #include "src/cuda/convolution/backward_data/deconv_int8_helper.cuh" | ||||
#include "src/cuda/convolution_helper/parameter.cuh" | |||||
#include "src/cuda/cutlass/singleton.h" | |||||
#include "src/cuda/utils.h" | |||||
using namespace megdnn; | using namespace megdnn; | ||||
using namespace cuda; | using namespace cuda; | ||||
@@ -70,6 +71,7 @@ size_t ConvolutionBackwardDataImpl::AlgoInt8NCHW4DotProdImplicitGemm:: | |||||
void ConvolutionBackwardDataImpl::AlgoInt8NCHW4DotProdImplicitGemm::exec( | void ConvolutionBackwardDataImpl::AlgoInt8NCHW4DotProdImplicitGemm::exec( | ||||
const ExecArgs& args) const { | const ExecArgs& args) const { | ||||
auto&& param = args.opr->param(); | |||||
auto&& fm = args.filter_meta; | auto&& fm = args.filter_meta; | ||||
size_t n = args.diff_layout->operator[](0), | size_t n = args.diff_layout->operator[](0), | ||||
co = args.diff_layout->operator[](1) * 4, | co = args.diff_layout->operator[](1) * 4, | ||||
@@ -81,6 +83,7 @@ void ConvolutionBackwardDataImpl::AlgoInt8NCHW4DotProdImplicitGemm::exec( | |||||
size_t fh = fm.spatial[0], fw = fm.spatial[1]; | size_t fh = fm.spatial[0], fw = fm.spatial[1]; | ||||
size_t sh = fm.stride[0], sw = fm.stride[1]; | size_t sh = fm.stride[0], sw = fm.stride[1]; | ||||
size_t ph = fm.padding[0], pw = fm.padding[1]; | size_t ph = fm.padding[0], pw = fm.padding[1]; | ||||
size_t dh = param.dilate_h, dw = param.dilate_w; | |||||
auto&& stream = cuda_stream(args.opr->handle()); | auto&& stream = cuda_stream(args.opr->handle()); | ||||
@@ -93,12 +96,6 @@ void ConvolutionBackwardDataImpl::AlgoInt8NCHW4DotProdImplicitGemm::exec( | |||||
filter_ptr, args.filter_tensor->compatible_ptr<int8_t>(), co, | filter_ptr, args.filter_tensor->compatible_ptr<int8_t>(), co, | ||||
ci, fh, fw, stream); | ci, fh, fw, stream); | ||||
} | } | ||||
convolution::ConvParam kern_param; | |||||
kern_param.n = n, kern_param.co = co, kern_param.ci = ci, | |||||
kern_param.hi = hi, kern_param.wi = wi, kern_param.ho = ho, | |||||
kern_param.wo = wo, kern_param.ph = ph, kern_param.pw = pw, | |||||
kern_param.sh = sh, kern_param.sw = sw, kern_param.fh = fh, | |||||
kern_param.fw = fw; | |||||
float diff_scale = | float diff_scale = | ||||
args.diff_layout->dtype.param<dtype::QuantizedS8>().scale, | args.diff_layout->dtype.param<dtype::QuantizedS8>().scale, | ||||
@@ -106,17 +103,60 @@ void ConvolutionBackwardDataImpl::AlgoInt8NCHW4DotProdImplicitGemm::exec( | |||||
args.filter_layout->dtype.param<dtype::QuantizedS8>().scale, | args.filter_layout->dtype.param<dtype::QuantizedS8>().scale, | ||||
grad_scale = | grad_scale = | ||||
args.grad_layout->dtype.param<dtype::QuantizedS8>().scale; | args.grad_layout->dtype.param<dtype::QuantizedS8>().scale; | ||||
float alpha = diff_scale * filter_scale / grad_scale; | |||||
cutlass_wrapper::do_deconv_int8_implicit_gemm_dp4a_ncdiv4hw4( | |||||
args.diff_tensor->compatible_ptr<int8_t>(), filter_ptr, | |||||
args.grad_tensor->compatible_ptr<int8_t>(), nullptr, kern_param, | |||||
alpha, | |||||
cutlass_wrapper::GemmCoord{m_algo_param.threadblock_m, | |||||
m_algo_param.threadblock_n, | |||||
m_algo_param.threadblock_k}, | |||||
cutlass_wrapper::GemmCoord{m_algo_param.warp_m, m_algo_param.warp_n, | |||||
m_algo_param.warp_k}, | |||||
m_algo_param.stage, stream); | |||||
// \note these constants of cutlass epilogue will be passed to struct | |||||
// `ConvolutionArguments` by pointer and interpreted as ElementCompute*, a | |||||
// different dtype here results in undefined epilogue behaviors | |||||
float alpha = diff_scale * filter_scale / grad_scale, beta = 0.f, | |||||
gamma = 0.f, delta = 0.f; | |||||
using namespace cutlass::library; | |||||
// only use 16x64x8_16x64x8_2stages impl | |||||
ConvolutionKey key{ | |||||
cutlass::conv::Operator::kDgrad, | |||||
NumericTypeID::kS8, | |||||
LayoutTypeID::kTensorNC4HW4, | |||||
NumericTypeID::kS8, | |||||
LayoutTypeID::kTensorK4RSC4, | |||||
NumericTypeID::kS8, | |||||
LayoutTypeID::kTensorNC4HW4, | |||||
NumericTypeID::kS32, | |||||
LayoutTypeID::kTensorNC4HW4, | |||||
cutlass::conv::ConvType::kConvolution, | |||||
m_algo_param.threadblock_m, | |||||
m_algo_param.threadblock_n, | |||||
m_algo_param.threadblock_k, | |||||
m_algo_param.warp_m, | |||||
m_algo_param.warp_n, | |||||
m_algo_param.warp_k, | |||||
1, | |||||
1, | |||||
4, | |||||
cutlass::epilogue::EpilogueType::kBiasAddLinearCombinationClamp, | |||||
m_algo_param.stage, | |||||
true, | |||||
false}; | |||||
const Operation* op = Singleton::get().operation_table.find_op(key); | |||||
// gcc prints warnings when size_t values are implicitly narrowed to int | |||||
cutlass::conv::Conv2dProblemSize problem_size{ | |||||
int(n), int(hi), int(wi), int(ci), | |||||
int(co), int(fh), int(fw), int(ho), | |||||
int(wo), int(ph), int(pw), int(sh), | |||||
int(sw), int(dh), int(dw), cutlass::conv::Mode::kCrossCorrelation}; | |||||
cutlass::library::ConvolutionArguments conv_args{ | |||||
problem_size, args.diff_tensor->compatible_ptr<int8_t>(), | |||||
filter_ptr, nullptr, | |||||
nullptr, args.grad_tensor->compatible_ptr<int8_t>(), | |||||
&alpha, &beta, | |||||
&gamma, &delta, | |||||
nullptr, nullptr, | |||||
nullptr, nullptr}; | |||||
cutlass_check(op->run(&conv_args, nullptr, stream)); | |||||
after_kernel_launch(); | after_kernel_launch(); | ||||
} | } | ||||
@@ -11,16 +11,16 @@ | |||||
* implied. | * implied. | ||||
*/ | */ | ||||
#include "./algo.h" | |||||
#include "src/cuda/utils.h" | |||||
#include "src/cuda/convolution/backward_data/algo.h" | |||||
#include "src/cuda/convolution_helper/parameter.cuh" | #include "src/cuda/convolution_helper/parameter.cuh" | ||||
#include "src/cuda/convolution/backward_data/cutlass_deconvolution_wrapper.cuh" | |||||
#include "src/cuda/cutlass/singleton.h" | |||||
#include "src/cuda/utils.h" | |||||
using namespace megdnn; | using namespace megdnn; | ||||
using namespace cuda; | using namespace cuda; | ||||
bool ConvolutionBackwardDataImpl::AlgoInt8NCHWDotProdImplicitGemm:: | |||||
is_available(const SizeArgs& args) const { | |||||
bool ConvolutionBackwardDataImpl::AlgoInt8NCHWDotProdImplicitGemm::is_available( | |||||
const SizeArgs& args) const { | |||||
auto&& fm = args.filter_meta; | auto&& fm = args.filter_meta; | ||||
if (fm.format != Param::Format::NCHW) | if (fm.format != Param::Format::NCHW) | ||||
return false; | return false; | ||||
@@ -42,7 +42,8 @@ bool ConvolutionBackwardDataImpl::AlgoInt8NCHWDotProdImplicitGemm:: | |||||
// TODO support group deconv int8 | // TODO support group deconv int8 | ||||
available &= (fm.group == 1); | available &= (fm.group == 1); | ||||
// ic and oc must be multiples of 4 | // ic and oc must be multiples of 4 | ||||
available &= ((fm.group * fm.icpg) % 4 == 0 && (fm.group * fm.ocpg) % 4 == 0); | |||||
available &= | |||||
((fm.group * fm.icpg) % 4 == 0 && (fm.group * fm.ocpg) % 4 == 0); | |||||
// mode must be cross correlation | // mode must be cross correlation | ||||
available &= !fm.should_flip; | available &= !fm.should_flip; | ||||
// mode must be 2D | // mode must be 2D | ||||
@@ -73,6 +74,7 @@ size_t ConvolutionBackwardDataImpl::AlgoInt8NCHWDotProdImplicitGemm:: | |||||
void ConvolutionBackwardDataImpl::AlgoInt8NCHWDotProdImplicitGemm::exec( | void ConvolutionBackwardDataImpl::AlgoInt8NCHWDotProdImplicitGemm::exec( | ||||
const ExecArgs& args) const { | const ExecArgs& args) const { | ||||
auto&& param = args.opr->param(); | |||||
auto&& fm = args.filter_meta; | auto&& fm = args.filter_meta; | ||||
size_t n = args.diff_layout->operator[](0), | size_t n = args.diff_layout->operator[](0), | ||||
co = args.diff_layout->operator[](1), | co = args.diff_layout->operator[](1), | ||||
@@ -84,6 +86,7 @@ void ConvolutionBackwardDataImpl::AlgoInt8NCHWDotProdImplicitGemm::exec( | |||||
size_t fh = fm.spatial[0], fw = fm.spatial[1]; | size_t fh = fm.spatial[0], fw = fm.spatial[1]; | ||||
size_t sh = fm.stride[0], sw = fm.stride[1]; | size_t sh = fm.stride[0], sw = fm.stride[1]; | ||||
size_t ph = fm.padding[0], pw = fm.padding[1]; | size_t ph = fm.padding[0], pw = fm.padding[1]; | ||||
size_t dh = param.dilate_h, dw = param.dilate_w; | |||||
auto&& stream = cuda_stream(args.opr->handle()); | auto&& stream = cuda_stream(args.opr->handle()); | ||||
@@ -120,26 +123,63 @@ void ConvolutionBackwardDataImpl::AlgoInt8NCHWDotProdImplicitGemm::exec( | |||||
} | } | ||||
int8_t* inner_grad_ptr = reinterpret_cast<int8_t*>(bundle.get(2)); | int8_t* inner_grad_ptr = reinterpret_cast<int8_t*>(bundle.get(2)); | ||||
convolution::ConvParam kern_param; | |||||
kern_param.n = n, kern_param.co = co, kern_param.ci = ci, | |||||
kern_param.hi = hi, kern_param.wi = wi, kern_param.ho = ho, | |||||
kern_param.wo = wo, kern_param.ph = ph, kern_param.pw = pw, | |||||
kern_param.sh = sh, kern_param.sw = sw, kern_param.fh = fh, | |||||
kern_param.fw = fw; | |||||
float diff_scale = | float diff_scale = | ||||
args.diff_layout->dtype.param<dtype::QuantizedS8>().scale, | args.diff_layout->dtype.param<dtype::QuantizedS8>().scale, | ||||
filter_scale = | filter_scale = | ||||
args.filter_layout->dtype.param<dtype::QuantizedS8>().scale, | args.filter_layout->dtype.param<dtype::QuantizedS8>().scale, | ||||
grad_scale = | grad_scale = | ||||
args.grad_layout->dtype.param<dtype::QuantizedS8>().scale; | args.grad_layout->dtype.param<dtype::QuantizedS8>().scale; | ||||
float alpha = diff_scale * filter_scale / grad_scale; | |||||
// \note these constants of cutlass epilogue will be passed to struct | |||||
// `ConvolutionArguments` by pointer and interpreted as ElementCompute*, a | |||||
// different dtype here results in undefined epilogue behaviors | |||||
float alpha = diff_scale * filter_scale / grad_scale, beta = 0.f, | |||||
gamma = 0.f, delta = 0.f; | |||||
using namespace cutlass::library; | |||||
// only use 16x64x8_16x64x8_2stages impl | // only use 16x64x8_16x64x8_2stages impl | ||||
cutlass_wrapper::do_deconv_int8_implicit_gemm_dp4a_ncdiv4hw4( | |||||
inner_diff_ptr, inner_filter_ptr, inner_grad_ptr, nullptr, | |||||
kern_param, alpha, cutlass_wrapper::GemmCoord{16, 64, 8}, | |||||
cutlass_wrapper::GemmCoord{16, 64, 8}, 2, stream); | |||||
ConvolutionKey key{ | |||||
cutlass::conv::Operator::kDgrad, | |||||
NumericTypeID::kS8, | |||||
LayoutTypeID::kTensorNC4HW4, | |||||
NumericTypeID::kS8, | |||||
LayoutTypeID::kTensorK4RSC4, | |||||
NumericTypeID::kS8, | |||||
LayoutTypeID::kTensorNC4HW4, | |||||
NumericTypeID::kS32, | |||||
LayoutTypeID::kTensorNC4HW4, | |||||
cutlass::conv::ConvType::kConvolution, | |||||
16, | |||||
64, | |||||
8, | |||||
16, | |||||
64, | |||||
8, | |||||
1, | |||||
1, | |||||
4, | |||||
cutlass::epilogue::EpilogueType::kBiasAddLinearCombinationClamp, | |||||
2, | |||||
true, | |||||
false}; | |||||
const Operation* op = Singleton::get().operation_table.find_op(key); | |||||
// gcc prints warnings when size_t values are implicitly narrowed to int | |||||
cutlass::conv::Conv2dProblemSize problem_size{ | |||||
int(n), int(hi), int(wi), int(ci), | |||||
int(co), int(fh), int(fw), int(ho), | |||||
int(wo), int(ph), int(pw), int(sh), | |||||
int(sw), int(dh), int(dw), cutlass::conv::Mode::kCrossCorrelation}; | |||||
cutlass::library::ConvolutionArguments conv_args{ | |||||
problem_size, inner_diff_ptr, inner_filter_ptr, nullptr, | |||||
nullptr, inner_grad_ptr, &alpha, &beta, | |||||
&gamma, &delta, nullptr, nullptr, | |||||
nullptr, nullptr}; | |||||
cutlass_check(op->run(&conv_args, nullptr, stream)); | |||||
after_kernel_launch(); | after_kernel_launch(); | ||||
@@ -0,0 +1,107 @@ | |||||
/*************************************************************************************************** | |||||
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. | |||||
* | |||||
* Redistribution and use in source and binary forms, with or without | |||||
*modification, are permitted provided that the following conditions are met: | |||||
* * Redistributions of source code must retain the above copyright notice, | |||||
*this list of conditions and the following disclaimer. | |||||
* * Redistributions in binary form must reproduce the above copyright | |||||
*notice, this list of conditions and the following disclaimer in the | |||||
*documentation and/or other materials provided with the distribution. | |||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its | |||||
*contributors may be used to endorse or promote products derived from this | |||||
*software without specific prior written permission. | |||||
* | |||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||||
*AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||||
*IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | |||||
*DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, | |||||
*INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, | |||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, | |||||
*DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY | |||||
*OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TOR (INCLUDING | |||||
*NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, | |||||
*EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||||
* | |||||
**************************************************************************************************/ | |||||
/** | |||||
* \file dnn/src/cuda/cutlass/arch_mappings.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#pragma once | |||||
#include "cutlass/arch/arch.h" | |||||
#include "cutlass/arch/mma.h" | |||||
///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
namespace cutlass { | |||||
namespace library { | |||||
///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
template <typename ArchTag, typename OperatorClass> | |||||
struct ArchMap; | |||||
template <> | |||||
struct ArchMap<arch::Sm50, arch::OpClassSimt> { | |||||
static int const kMin = 50; | |||||
static int const kMax = 1024; | |||||
}; | |||||
template <> | |||||
struct ArchMap<arch::Sm60, arch::OpClassSimt> { | |||||
static int const kMin = 60; | |||||
static int const kMax = 1024; | |||||
}; | |||||
template <> | |||||
struct ArchMap<arch::Sm61, arch::OpClassSimt> { | |||||
static int const kMin = 61; | |||||
static int const kMax = 1024; | |||||
}; | |||||
template <> | |||||
struct ArchMap<arch::Sm70, arch::OpClassWmmaTensorOp> { | |||||
static int const kMin = 70; | |||||
static int const kMax = 1024; | |||||
}; | |||||
template <> | |||||
struct ArchMap<arch::Sm70, arch::OpClassTensorOp> { | |||||
static int const kMin = 70; | |||||
static int const kMax = 75; | |||||
}; | |||||
template <typename OperatorClass> | |||||
struct ArchMap<arch::Sm75, OperatorClass> { | |||||
static int const kMin = 75; | |||||
static int const kMax = 1024; | |||||
}; | |||||
template <typename OperatorClass> | |||||
struct ArchMap<arch::Sm80, OperatorClass> { | |||||
static int const kMin = 80; | |||||
static int const kMax = 1024; | |||||
}; | |||||
template <typename OperatorClass> | |||||
struct ArchMap<arch::Sm86, OperatorClass> { | |||||
static int const kMin = 86; | |||||
static int const kMax = 1024; | |||||
}; | |||||
///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
} // namespace library | |||||
} // namespace cutlass | |||||
///////////////////////////////////////////////////////////////////////////////////////////////// |
@@ -0,0 +1,307 @@ | |||||
/*************************************************************************************************** | |||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. | |||||
* | |||||
* Redistribution and use in source and binary forms, with or without | |||||
*modification, are permitted provided that the following conditions are met: | |||||
* * Redistributions of source code must retain the above copyright notice, | |||||
*this list of conditions and the following disclaimer. | |||||
* * Redistributions in binary form must reproduce the above copyright | |||||
*notice, this list of conditions and the following disclaimer in the | |||||
*documentation and/or other materials provided with the distribution. | |||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its | |||||
*contributors may be used to endorse or promote products derived from this | |||||
*software without specific prior written permission. | |||||
* | |||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||||
*AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||||
*IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | |||||
*DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, | |||||
*INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, | |||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, | |||||
*DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY | |||||
*OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TOR (INCLUDING | |||||
*NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, | |||||
*EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||||
* | |||||
**************************************************************************************************/ | |||||
/** | |||||
* \file dnn/src/cuda/cutlass/convolution_operation.h | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#pragma once | |||||
#include "cutlass/convolution/device/convolution.h" | |||||
#include "src/cuda/cutlass/library_internal.h" | |||||
/////////////////////////////////////////////////////////////////////////////////////////////////// | |||||
namespace cutlass { | |||||
namespace library { | |||||
/////////////////////////////////////////////////////////////////////////////////////////////////// | |||||
template <typename Operator_> | |||||
class ConvolutionOperationBase : public Operation { | |||||
public: | |||||
using Operator = Operator_; | |||||
using ElementSrc = typename Operator::ElementSrc; | |||||
using LayoutSrc = typename Operator::LayoutSrc; | |||||
using ElementFilter = typename Operator::ElementFilter; | |||||
using LayoutFilter = typename Operator::LayoutFilter; | |||||
using ElementDst = typename Operator::ElementDst; | |||||
using LayoutDst = typename Operator::LayoutDst; | |||||
using ElementBias = typename Operator::ElementBias; | |||||
using LayoutBias = typename Operator::LayoutBias; | |||||
using ElementAccumulator = typename Operator::ElementAccumulator; | |||||
ConvolutionOperationBase(char const* name = "unknown_convolution") { | |||||
m_description.name = name; | |||||
m_description.provider = Provider::kCUTLASS; | |||||
m_description.kind = OperationKind::kConvolution; | |||||
m_description.conv_op = Operator::kConvolutionalOperator; | |||||
m_description.tile_description.threadblock_shape = make_Coord( | |||||
Operator::ThreadblockShape::kM, Operator::ThreadblockShape::kN, | |||||
Operator::ThreadblockShape::kK); | |||||
m_description.tile_description.threadblock_stages = Operator::kStages; | |||||
m_description.tile_description.warp_count = | |||||
make_Coord(Operator::ConvolutionKernel::WarpCount::kM, | |||||
Operator::ConvolutionKernel::WarpCount::kN, | |||||
Operator::ConvolutionKernel::WarpCount::kK); | |||||
m_description.tile_description.math_instruction.instruction_shape = | |||||
make_Coord(Operator::InstructionShape::kM, | |||||
Operator::InstructionShape::kN, | |||||
Operator::InstructionShape::kK); | |||||
m_description.tile_description.math_instruction.element_accumulator = | |||||
NumericTypeMap<ElementAccumulator>::kId; | |||||
m_description.tile_description.math_instruction.opcode_class = | |||||
OpcodeClassMap<typename Operator::OperatorClass>::kId; | |||||
m_description.tile_description.math_instruction.math_operation = | |||||
MathOperationMap<typename Operator::Operator>::kId; | |||||
m_description.tile_description.minimum_compute_capability = | |||||
ArchMap<typename Operator::ArchTag, | |||||
typename Operator::OperatorClass>::kMin; | |||||
m_description.tile_description.maximum_compute_capability = | |||||
ArchMap<typename Operator::ArchTag, | |||||
typename Operator::OperatorClass>::kMax; | |||||
m_description.src = make_TensorDescription<ElementSrc, LayoutSrc>( | |||||
Operator::kAlignmentSrc); | |||||
m_description.filter = | |||||
make_TensorDescription<ElementFilter, LayoutFilter>( | |||||
Operator::kAlignmentFilter); | |||||
m_description.dst = make_TensorDescription<ElementDst, LayoutDst>( | |||||
Operator::kAlignmentDst); | |||||
m_description.bias = make_TensorDescription<ElementBias, LayoutBias>( | |||||
Operator::kAlignmentDst); | |||||
m_description.convolution_type = Operator::kConvolutionType; | |||||
m_description.arch_tag = ArchTagMap<typename Operator::ArchTag>::kId; | |||||
m_description.epilogue_type = Operator::EpilogueOutputOp::kType; | |||||
m_description.epilogue_count = Operator::EpilogueOutputOp::kCount; | |||||
m_description.threadblock_swizzle = ThreadblockSwizzleMap< | |||||
typename Operator::ThreadblockSwizzle>::kId; | |||||
m_description.need_load_from_const_mem = | |||||
Operator::kNeedLoadFromConstMem; | |||||
m_description.gemm_mode = Operator::kGemmMode; | |||||
m_description.without_shared_load = Operator::kWithoutSharedLoad; | |||||
} | |||||
virtual OperationDescription const& description() const { | |||||
return m_description; | |||||
} | |||||
protected: | |||||
ConvolutionDescription m_description; | |||||
}; | |||||
/////////////////////////////////////////////////////////////////////////////////////////////////// | |||||
namespace detail { | |||||
template <typename EpilogueOp, epilogue::EpilogueType type> | |||||
struct init_epilogue_param_; | |||||
template <typename EpilogueOp> | |||||
struct init_epilogue_param_<EpilogueOp, | |||||
epilogue::EpilogueType::kBiasAddLinearCombination> { | |||||
using ElementCompute = typename EpilogueOp::ElementCompute; | |||||
typename EpilogueOp::Params get(ConvolutionArguments const* conv_args) { | |||||
return {*static_cast<ElementCompute const*>(conv_args->alpha), | |||||
*static_cast<ElementCompute const*>(conv_args->beta), | |||||
*static_cast<ElementCompute const*>(conv_args->gamma), | |||||
*static_cast<ElementCompute const*>(conv_args->delta)}; | |||||
} | |||||
}; | |||||
template <typename EpilogueOp> | |||||
struct init_epilogue_param_< | |||||
EpilogueOp, epilogue::EpilogueType::kBiasAddLinearCombinationClamp> { | |||||
using ElementCompute = typename EpilogueOp::ElementCompute; | |||||
typename EpilogueOp::Params get(ConvolutionArguments const* conv_args) { | |||||
return {*static_cast<ElementCompute const*>(conv_args->alpha), | |||||
*static_cast<ElementCompute const*>(conv_args->beta), | |||||
*static_cast<ElementCompute const*>(conv_args->gamma), | |||||
*static_cast<ElementCompute const*>(conv_args->delta)}; | |||||
} | |||||
}; | |||||
template <typename EpilogueOp> | |||||
struct init_epilogue_param_< | |||||
EpilogueOp, epilogue::EpilogueType::kBiasAddLinearCombinationRelu> { | |||||
using ElementCompute = typename EpilogueOp::ElementCompute; | |||||
typename EpilogueOp::Params get(ConvolutionArguments const* conv_args) { | |||||
return {*static_cast<ElementCompute const*>(conv_args->alpha), | |||||
*static_cast<ElementCompute const*>(conv_args->beta), | |||||
*static_cast<ElementCompute const*>(conv_args->gamma), | |||||
*static_cast<ElementCompute const*>(conv_args->threshold), | |||||
*static_cast<ElementCompute const*>(conv_args->delta), | |||||
*static_cast<ElementCompute const*>(conv_args->theta)}; | |||||
} | |||||
}; | |||||
template <typename EpilogueOp> | |||||
struct init_epilogue_param_< | |||||
EpilogueOp, | |||||
epilogue::EpilogueType::kBiasAddLinearCombinationReluClamp> { | |||||
using ElementCompute = typename EpilogueOp::ElementCompute; | |||||
typename EpilogueOp::Params get(ConvolutionArguments const* conv_args) { | |||||
return {*static_cast<ElementCompute const*>(conv_args->alpha), | |||||
*static_cast<ElementCompute const*>(conv_args->beta), | |||||
*static_cast<ElementCompute const*>(conv_args->gamma), | |||||
*static_cast<ElementCompute const*>(conv_args->threshold), | |||||
*static_cast<ElementCompute const*>(conv_args->delta), | |||||
*static_cast<ElementCompute const*>(conv_args->theta)}; | |||||
} | |||||
}; | |||||
template <typename EpilogueOp> | |||||
struct init_epilogue_param_< | |||||
EpilogueOp, epilogue::EpilogueType::kBiasAddLinearCombinationHSwish> { | |||||
using ElementCompute = typename EpilogueOp::ElementCompute; | |||||
typename EpilogueOp::Params get(ConvolutionArguments const* conv_args) { | |||||
return {*static_cast<ElementCompute const*>(conv_args->alpha), | |||||
*static_cast<ElementCompute const*>(conv_args->beta), | |||||
*static_cast<ElementCompute const*>(conv_args->gamma), | |||||
*static_cast<ElementCompute const*>(conv_args->scale), | |||||
*static_cast<ElementCompute const*>(conv_args->delta), | |||||
*static_cast<ElementCompute const*>(conv_args->theta)}; | |||||
} | |||||
}; | |||||
template <typename EpilogueOp> | |||||
struct init_epilogue_param_< | |||||
EpilogueOp, | |||||
epilogue::EpilogueType::kBiasAddLinearCombinationHSwishClamp> { | |||||
using ElementCompute = typename EpilogueOp::ElementCompute; | |||||
typename EpilogueOp::Params get(ConvolutionArguments const* conv_args) { | |||||
return {*static_cast<ElementCompute const*>(conv_args->alpha), | |||||
*static_cast<ElementCompute const*>(conv_args->beta), | |||||
*static_cast<ElementCompute const*>(conv_args->gamma), | |||||
*static_cast<ElementCompute const*>(conv_args->scale), | |||||
*static_cast<ElementCompute const*>(conv_args->delta), | |||||
*static_cast<ElementCompute const*>(conv_args->theta)}; | |||||
} | |||||
}; | |||||
} // namespace detail | |||||
template <typename EpilogueOp> | |||||
struct init_epilogue_param | |||||
: public detail::init_epilogue_param_<EpilogueOp, EpilogueOp::kType> {}; | |||||
/////////////////////////////////////////////////////////////////////////////////////////////////// | |||||
template <typename Operator_> | |||||
class ConvolutionOperation : public ConvolutionOperationBase<Operator_> { | |||||
public: | |||||
using Operator = Operator_; | |||||
using ElementSrc = typename Operator::ElementSrc; | |||||
using LayoutSrc = typename Operator::LayoutSrc; | |||||
using ElementFilter = typename Operator::ElementFilter; | |||||
using LayoutFilter = typename Operator::LayoutFilter; | |||||
using ElementBias = typename Operator::ElementBias; | |||||
using LayoutBias = typename Operator::LayoutBias; | |||||
using ElementDst = typename Operator::ElementDst; | |||||
using LayoutDst = typename Operator::LayoutDst; | |||||
using ElementAccumulator = typename Operator::ElementAccumulator; | |||||
using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; | |||||
using OperatorArguments = typename Operator::Arguments; | |||||
ConvolutionOperation(char const* name = "unknown_gemm") | |||||
: ConvolutionOperationBase<Operator_>(name) {} | |||||
virtual Status run(void const* arguments_ptr, | |||||
void* device_workspace = nullptr, | |||||
cudaStream_t stream = nullptr) const { | |||||
cutlass::conv::Operator conv_op = this->m_description.conv_op; | |||||
ConvolutionArguments const* conv_args = | |||||
reinterpret_cast<ConvolutionArguments const*>(arguments_ptr); | |||||
const auto& ps = conv_args->problem_size; | |||||
OperatorArguments args; | |||||
args.problem_size = ps; | |||||
args.ref_src = { | |||||
static_cast<ElementSrc*>(const_cast<void*>(conv_args->src)), | |||||
LayoutSrc::packed(implicit_gemm_tensor_a_extent(conv_op, ps))}; | |||||
args.ref_filter = {static_cast<ElementFilter*>( | |||||
const_cast<void*>(conv_args->filter)), | |||||
LayoutFilter::packed( | |||||
implicit_gemm_tensor_b_extent(conv_op, ps))}; | |||||
args.ref_bias = { | |||||
static_cast<ElementBias*>(const_cast<void*>(conv_args->bias)), | |||||
LayoutBias::packed( | |||||
implicit_gemm_tensor_bias_extent(conv_op, ps))}; | |||||
args.ref_z = { | |||||
static_cast<ElementDst*>(const_cast<void*>(conv_args->z)), | |||||
LayoutDst::packed(implicit_gemm_tensor_c_extent(conv_op, ps))}; | |||||
args.ref_dst = { | |||||
static_cast<ElementDst*>(conv_args->dst), | |||||
LayoutDst::packed(implicit_gemm_tensor_c_extent(conv_op, ps))}; | |||||
args.output_op = | |||||
init_epilogue_param<typename Operator::EpilogueOutputOp>().get( | |||||
conv_args); | |||||
if (conv_args->extra_param) { | |||||
args.extra_param = | |||||
*reinterpret_cast<typename Operator::ExtraParam const*>( | |||||
conv_args->extra_param); | |||||
} | |||||
Operator op; | |||||
Status status = op.initialize(args, device_workspace); | |||||
if (status != Status::kSuccess) { | |||||
return status; | |||||
} | |||||
return op.run(stream); | |||||
} | |||||
}; | |||||
/////////////////////////////////////////////////////////////////////////////////////////////////// | |||||
} // namespace library | |||||
} // namespace cutlass | |||||
/////////////////////////////////////////////////////////////////////////////////////////////////// |
@@ -0,0 +1,202 @@ | |||||
/*************************************************************************************************** | |||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. | |||||
* | |||||
* Redistribution and use in source and binary forms, with or without | |||||
*modification, are permitted provided that the following conditions are met: | |||||
* * Redistributions of source code must retain the above copyright notice, | |||||
*this list of conditions and the following disclaimer. | |||||
* * Redistributions in binary form must reproduce the above copyright | |||||
*notice, this list of conditions and the following disclaimer in the | |||||
*documentation and/or other materials provided with the distribution. | |||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its | |||||
*contributors may be used to endorse or promote products derived from this | |||||
*software without specific prior written permission. | |||||
* | |||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||||
*AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||||
*IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | |||||
*DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, | |||||
*INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, | |||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, | |||||
*DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY | |||||
*OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TOR (INCLUDING | |||||
*NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, | |||||
*EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||||
* | |||||
**************************************************************************************************/ | |||||
/** | |||||
* \file dnn/src/cuda/cutlass/gemm_operation.h | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#pragma once | |||||
#include "cutlass/gemm/device/gemm.h" | |||||
#include "src/cuda/cutlass/library_internal.h" | |||||
/////////////////////////////////////////////////////////////////////////////////////////////////// | |||||
namespace cutlass { | |||||
namespace library { | |||||
/////////////////////////////////////////////////////////////////////////////////////////////////// | |||||
/// Check whether Operator has member ReductionKernel using SFINAE (Substitution | |||||
/// Failure Is Not An Error) | |||||
template <typename Operator> | |||||
struct split_k_mode { | |||||
template <typename T> | |||||
static char check(typename T::ReductionKernel*); | |||||
template <typename T> | |||||
static int check(...); | |||||
SplitKMode operator()() { | |||||
if (sizeof(check<Operator>(0)) == sizeof(char)) { | |||||
// cutlass::gemm::device::GemmSplitKParallel | |||||
return SplitKMode::kParallel; | |||||
} else { | |||||
// cutlass::gemm::device::Gemm | |||||
return SplitKMode::kNone; | |||||
} | |||||
} | |||||
}; | |||||
/////////////////////////////////////////////////////////////////////////////////////////////////// | |||||
template <typename Operator_> | |||||
class GemmOperationBase : public Operation { | |||||
public: | |||||
using Operator = Operator_; | |||||
using ElementA = typename Operator::ElementA; | |||||
using LayoutA = typename Operator::LayoutA; | |||||
using ElementB = typename Operator::ElementB; | |||||
using LayoutB = typename Operator::LayoutB; | |||||
using ElementC = typename Operator::ElementC; | |||||
using LayoutC = typename Operator::LayoutC; | |||||
using ElementAccumulator = typename Operator::ElementAccumulator; | |||||
GemmOperationBase(char const* name = "unknown_gemm") { | |||||
m_description.name = name; | |||||
m_description.provider = Provider::kCUTLASS; | |||||
m_description.kind = OperationKind::kGemm; | |||||
m_description.gemm_kind = GemmKind::kGemm; | |||||
m_description.tile_description.threadblock_shape = make_Coord( | |||||
Operator::ThreadblockShape::kM, Operator::ThreadblockShape::kN, | |||||
Operator::ThreadblockShape::kK); | |||||
m_description.tile_description.threadblock_stages = Operator::kStages; | |||||
m_description.tile_description.warp_count = | |||||
make_Coord(Operator::GemmKernel::WarpCount::kM, | |||||
Operator::GemmKernel::WarpCount::kN, | |||||
Operator::GemmKernel::WarpCount::kK); | |||||
m_description.tile_description.math_instruction.instruction_shape = | |||||
make_Coord(Operator::InstructionShape::kM, | |||||
Operator::InstructionShape::kN, | |||||
Operator::InstructionShape::kK); | |||||
m_description.tile_description.math_instruction.element_accumulator = | |||||
NumericTypeMap<ElementAccumulator>::kId; | |||||
m_description.tile_description.math_instruction.opcode_class = | |||||
OpcodeClassMap<typename Operator::OperatorClass>::kId; | |||||
m_description.tile_description.math_instruction.math_operation = | |||||
MathOperationMap<typename Operator::Operator>::kId; | |||||
m_description.tile_description.minimum_compute_capability = | |||||
ArchMap<typename Operator::ArchTag, | |||||
typename Operator::OperatorClass>::kMin; | |||||
m_description.tile_description.maximum_compute_capability = | |||||
ArchMap<typename Operator::ArchTag, | |||||
typename Operator::OperatorClass>::kMax; | |||||
m_description.A = make_TensorDescription<ElementA, LayoutA>( | |||||
Operator::kAlignmentA); | |||||
m_description.B = make_TensorDescription<ElementB, LayoutB>( | |||||
Operator::kAlignmentB); | |||||
m_description.C = make_TensorDescription<ElementC, LayoutC>( | |||||
Operator::kAlignmentC); | |||||
m_description.stages = Operator::kStages; | |||||
split_k_mode<Operator> mode; | |||||
m_description.split_k_mode = mode(); | |||||
} | |||||
virtual OperationDescription const& description() const { | |||||
return m_description; | |||||
} | |||||
protected: | |||||
GemmDescription m_description; | |||||
}; | |||||
/////////////////////////////////////////////////////////////////////////////////////////////////// | |||||
template <typename Operator_> | |||||
class GemmOperation : public GemmOperationBase<Operator_> { | |||||
public: | |||||
using Operator = Operator_; | |||||
using ElementA = typename Operator::ElementA; | |||||
using LayoutA = typename Operator::LayoutA; | |||||
using ElementB = typename Operator::ElementB; | |||||
using LayoutB = typename Operator::LayoutB; | |||||
using ElementC = typename Operator::ElementC; | |||||
using LayoutC = typename Operator::LayoutC; | |||||
using ElementAccumulator = typename Operator::ElementAccumulator; | |||||
using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; | |||||
using OperatorArguments = typename Operator::Arguments; | |||||
GemmOperation(char const* name = "unknown_gemm") | |||||
: GemmOperationBase<Operator_>(name) {} | |||||
virtual Status run(void const* arguments_ptr, | |||||
void* device_workspace = nullptr, | |||||
cudaStream_t stream = nullptr) const { | |||||
GemmArguments const* gemm_args = | |||||
reinterpret_cast<GemmArguments const*>(arguments_ptr); | |||||
OperatorArguments args; | |||||
args.problem_size = gemm_args->problem_size; | |||||
args.ref_A = {static_cast<ElementA const*>(gemm_args->A), | |||||
int(gemm_args->lda)}; | |||||
args.ref_B = {static_cast<ElementB const*>(gemm_args->B), | |||||
int(gemm_args->ldb)}; | |||||
args.ref_C = {static_cast<ElementC const*>(gemm_args->C), | |||||
int(gemm_args->ldc)}; | |||||
args.ref_D = {static_cast<ElementC*>(gemm_args->D), | |||||
int(gemm_args->ldd)}; | |||||
args.split_k_slices = gemm_args->split_k_slices; | |||||
args.epilogue = {*static_cast<ElementCompute const*>(gemm_args->alpha), | |||||
*static_cast<ElementCompute const*>(gemm_args->beta)}; | |||||
Operator op; | |||||
Status status = op.initialize(args, device_workspace); | |||||
if (status != Status::kSuccess) { | |||||
return status; | |||||
} | |||||
return op.run(stream); | |||||
} | |||||
}; | |||||
/////////////////////////////////////////////////////////////////////////////////////////////////// | |||||
} // namespace library | |||||
} // namespace cutlass | |||||
/////////////////////////////////////////////////////////////////////////////////////////////////// |
@@ -0,0 +1,76 @@ | |||||
/*************************************************************************************************** | |||||
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. | |||||
* | |||||
* Redistribution and use in source and binary forms, with or without | |||||
*modification, are permitted provided that the following conditions are met: | |||||
* * Redistributions of source code must retain the above copyright notice, | |||||
*this list of conditions and the following disclaimer. | |||||
* * Redistributions in binary form must reproduce the above copyright | |||||
*notice, this list of conditions and the following disclaimer in the | |||||
*documentation and/or other materials provided with the distribution. | |||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its | |||||
*contributors may be used to endorse or promote products derived from this | |||||
*software without specific prior written permission. | |||||
* | |||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||||
*AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||||
*IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | |||||
*DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, | |||||
*INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, | |||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, | |||||
*DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY | |||||
*OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TOR (INCLUDING | |||||
*NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, | |||||
*EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||||
* | |||||
**************************************************************************************************/ | |||||
/** | |||||
* \file dnn/src/cuda/cutlass/initialize_all.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#include "src/cuda/cutlass/manifest.h" | |||||
///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
namespace cutlass { | |||||
namespace library { | |||||
///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
#if __CUDACC_VER_MAJOR__ > 9 || \ | |||||
(__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2) | |||||
void initialize_all_gemm_simt_operations(Manifest& manifest); | |||||
void initialize_all_conv2d_simt_operations(Manifest& manifest); | |||||
void initialize_all_conv2d_tensorop8816_operations(Manifest& manifest); | |||||
void initialize_all_conv2d_tensorop8832_operations(Manifest& manifest); | |||||
void initialize_all_deconv_simt_operations(Manifest& manifest); | |||||
void initialize_all(Manifest& manifest) { | |||||
initialize_all_gemm_simt_operations(manifest); | |||||
initialize_all_conv2d_simt_operations(manifest); | |||||
initialize_all_conv2d_tensorop8816_operations(manifest); | |||||
initialize_all_conv2d_tensorop8832_operations(manifest); | |||||
initialize_all_deconv_simt_operations(manifest); | |||||
} | |||||
#else | |||||
void initialize_all(Manifest& manifest) {} | |||||
#endif | |||||
///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
} // namespace library | |||||
} // namespace cutlass | |||||
///////////////////////////////////////////////////////////////////////////////////////////////// |
@@ -0,0 +1,541 @@ | |||||
/*************************************************************************************************** | |||||
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. | |||||
* | |||||
* Redistribution and use in source and binary forms, with or without | |||||
*modification, are permitted provided that the following conditions are met: | |||||
* * Redistributions of source code must retain the above copyright notice, | |||||
*this list of conditions and the following disclaimer. | |||||
* * Redistributions in binary form must reproduce the above copyright | |||||
*notice, this list of conditions and the following disclaimer in the | |||||
*documentation and/or other materials provided with the distribution. | |||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its | |||||
*contributors may be used to endorse or promote products derived from this | |||||
*software without specific prior written permission. | |||||
* | |||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||||
*AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||||
*IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | |||||
*DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, | |||||
*INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, | |||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, | |||||
*DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY | |||||
*OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TOR (INCLUDING | |||||
*NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, | |||||
*EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||||
* | |||||
**************************************************************************************************/ | |||||
/** | |||||
* \file dnn/src/cuda/cutlass/library.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#pragma once | |||||
///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
#include <cuda_runtime.h> | |||||
#include <cstdint> | |||||
#include <stdexcept> | |||||
#include <string> | |||||
#include <vector> | |||||
#pragma GCC diagnostic push | |||||
#pragma GCC diagnostic ignored "-Wreorder" | |||||
#pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||||
#pragma GCC diagnostic ignored "-Wunused-parameter" | |||||
#include "cutlass/cutlass.h" | |||||
#include "cutlass/layout/tensor.h" | |||||
#include "cutlass/matrix_coord.h" | |||||
#include "cutlass/tensor_coord.h" | |||||
#include "cutlass/conv/conv2d_problem_size.h" | |||||
#include "cutlass/conv/convolution.h" | |||||
#include "cutlass/epilogue/epilogue.h" | |||||
#include "cutlass/gemm/gemm.h" | |||||
#pragma GCC diagnostic pop | |||||
///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
namespace cutlass { | |||||
namespace library { | |||||
///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
/// Layout type identifier | |||||
enum class LayoutTypeID { | |||||
kUnknown, | |||||
kColumnMajor, | |||||
kRowMajor, | |||||
kColumnMajorInterleavedK2, | |||||
kRowMajorInterleavedK2, | |||||
kColumnMajorInterleavedK4, | |||||
kRowMajorInterleavedK4, | |||||
kColumnMajorInterleavedK16, | |||||
kRowMajorInterleavedK16, | |||||
kColumnMajorInterleavedK32, | |||||
kRowMajorInterleavedK32, | |||||
kColumnMajorInterleavedK64, | |||||
kRowMajorInterleavedK64, | |||||
kTensorNCHW, | |||||
kTensorNCDHW, | |||||
kTensorNHWC, | |||||
kTensorNDHWC, | |||||
kTensorNC4HW4, | |||||
kTensorC4RSK4, | |||||
kTensorNC8HW8, | |||||
kTensorC8RSK8, | |||||
kTensorNC16HW16, | |||||
kTensorC16RSK16, | |||||
kTensorNC32HW32, | |||||
kTensorC32RSK32, | |||||
kTensorNC64HW64, | |||||
kTensorC64RSK64, | |||||
kTensorK4RSC4, | |||||
kInvalid | |||||
}; | |||||
/// Numeric data type | |||||
enum class NumericTypeID { | |||||
kUnknown, | |||||
kVoid, | |||||
kB1, | |||||
kU2, | |||||
kU4, | |||||
kU8, | |||||
kU16, | |||||
kU32, | |||||
kU64, | |||||
kS2, | |||||
kS4, | |||||
kS8, | |||||
kS16, | |||||
kS32, | |||||
kS64, | |||||
kF16, | |||||
kBF16, | |||||
kTF32, | |||||
kF32, | |||||
kF64, | |||||
kCF16, | |||||
kCBF16, | |||||
kCF32, | |||||
kCTF32, | |||||
kCF64, | |||||
kCS2, | |||||
kCS4, | |||||
kCS8, | |||||
kCS16, | |||||
kCS32, | |||||
kCS64, | |||||
kCU2, | |||||
kCU4, | |||||
kCU8, | |||||
kCU16, | |||||
kCU32, | |||||
kCU64, | |||||
kInvalid | |||||
}; | |||||
/// Enumerated type describing a transformation on a complex value. | |||||
enum class ComplexTransform { kNone, kConjugate, kInvalid }; | |||||
/// Providers | |||||
enum class Provider { | |||||
kNone, | |||||
kCUTLASS, | |||||
kReferenceHost, | |||||
kReferenceDevice, | |||||
kCUBLAS, | |||||
kCUDNN, | |||||
kInvalid | |||||
}; | |||||
///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
/// Enumeration indicating the kind of operation | |||||
enum class OperationKind { | |||||
kGemm, | |||||
kConv2d, | |||||
kConv3d, | |||||
kConvolution, | |||||
kEqGemm, | |||||
kSparseGemm, | |||||
kReduction, | |||||
kInvalid | |||||
}; | |||||
/// Enumeration indicating whether scalars are in host or device memory | |||||
enum class ScalarPointerMode { kHost, kDevice, kInvalid }; | |||||
/// Describes how reductions are performed across threadblocks | |||||
enum class SplitKMode { kNone, kSerial, kParallel, kParallelSerial, kInvalid }; | |||||
/// Indicates the classificaition of the math instruction | |||||
enum class OpcodeClassID { | |||||
kSimt, | |||||
kTensorOp, | |||||
kWmmaTensorOp, | |||||
kSparseTensorOp, | |||||
kInvalid | |||||
}; | |||||
enum class ArchTagID { | |||||
kSm50, | |||||
kSm60, | |||||
kSm61, | |||||
kSm70, | |||||
kSm72, | |||||
kSm75, | |||||
kSm80, | |||||
kSm86, | |||||
kInvalid | |||||
}; | |||||
enum class MathOperationID { | |||||
kAdd, | |||||
kMultiplyAdd, | |||||
kMultiplyAddSaturate, | |||||
kMultiplyAddFastBF16, | |||||
kMultiplyAddFastF16, | |||||
kMultiplyAddComplex, | |||||
kMultiplyAddGaussianComplex, | |||||
kXorPopc, | |||||
kInvalid | |||||
}; | |||||
enum class ThreadblockSwizzleID { | |||||
kGemmIdentity, | |||||
kGemmHorizontal, | |||||
kGemmBatchedIdentity, | |||||
kGemmSplitKIdentity, | |||||
kGemmSplitKHorizontal, | |||||
kGemvBatchedStridedDefault, | |||||
kGemvBatchedStridedReduction, | |||||
kConvolutionFpropCxRSKx, | |||||
kConvolutionDgradCxRSKx, | |||||
kConvolutionFpropNCxHWx, | |||||
kConvolutionFpropTrans, | |||||
kConvolutionDgradNCxHWx, | |||||
kInvalid | |||||
}; | |||||
///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
/// Enumeration indicating what kind of GEMM operation to perform | |||||
enum class GemmKind { | |||||
kGemm, | |||||
kSparse, | |||||
kUniversal, | |||||
kPlanarComplex, | |||||
kPlanarComplexArray, | |||||
kInvalid | |||||
}; | |||||
/// Mode of Universal GEMM | |||||
using GemmUniversalMode = cutlass::gemm::GemmUniversalMode; | |||||
/// Enumeration indicating what kind of Conv2d operation to perform | |||||
enum class ConvKind { kUnknown, kFprop, kDgrad, kWgrad, kInvalid }; | |||||
enum class ConvModeID { kCrossCorrelation, kConvolution, kInvalid }; | |||||
// Iterator algorithm enum in order of general performance-efficiency | |||||
enum class IteratorAlgorithmID { kNone, kAnalytic, kOptimized, kInvalid }; | |||||
enum class EpilogueKind { | |||||
kUnknown, | |||||
kBiasAddLinearCombination, | |||||
kBiasAddLinearCombinationClamp, | |||||
kBiasAddLInearCombinationHSwish, | |||||
kBiasAddLInearCombinationHSwishClamp, | |||||
kBiasAddLInearCombinationRelu, | |||||
kBiasAddLInearCombinationReluClamp, | |||||
kConversion, | |||||
kLinearCombination, | |||||
kLinearCombinationClamp, | |||||
kLinearCombinationPlanarComplex, | |||||
kLinearCombinationRelu, | |||||
kLinearCombinationSigmoid, | |||||
kInvalid | |||||
}; | |||||
///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
struct MathInstructionDescription { | |||||
/// Shape of the target math instruction | |||||
cutlass::gemm::GemmCoord instruction_shape; | |||||
/// Describes the data type of the internal accumulator | |||||
NumericTypeID element_accumulator; | |||||
/// Classification of math instruction | |||||
OpcodeClassID opcode_class; | |||||
/// Type of math operation performed | |||||
MathOperationID math_operation; | |||||
// | |||||
// Methods | |||||
// | |||||
MathInstructionDescription( | |||||
cutlass::gemm::GemmCoord instruction_shape = | |||||
cutlass::gemm::GemmCoord(), | |||||
NumericTypeID element_accumulator = NumericTypeID::kInvalid, | |||||
OpcodeClassID opcode_class = OpcodeClassID::kInvalid, | |||||
MathOperationID math_operation = MathOperationID::kMultiplyAdd) | |||||
: instruction_shape(instruction_shape), | |||||
element_accumulator(element_accumulator), | |||||
opcode_class(opcode_class), | |||||
math_operation(math_operation) {} | |||||
// Equality operator | |||||
inline bool operator==(MathInstructionDescription const& rhs) const { | |||||
return ((instruction_shape == rhs.instruction_shape) && | |||||
(element_accumulator == rhs.element_accumulator) && | |||||
(opcode_class == rhs.opcode_class) && | |||||
(math_operation == rhs.math_operation)); | |||||
} | |||||
// Inequality operator | |||||
inline bool operator!=(MathInstructionDescription const& rhs) const { | |||||
return !(*this == rhs); | |||||
} | |||||
}; | |||||
/// Structure describing the tiled structure of a GEMM-like computation | |||||
struct TileDescription { | |||||
/// Describes the shape of a threadblock (in elements) | |||||
cutlass::gemm::GemmCoord threadblock_shape; | |||||
/// Describes the number of pipeline stages in the threadblock-scoped | |||||
/// mainloop | |||||
int threadblock_stages; | |||||
/// Number of warps in each logical dimension | |||||
cutlass::gemm::GemmCoord warp_count; | |||||
/// Core math instruction | |||||
MathInstructionDescription math_instruction; | |||||
/// Minimum compute capability (e.g. 70, 75) of a device eligible to run the | |||||
/// operation. | |||||
int minimum_compute_capability; | |||||
/// Minimum compute capability (e.g. 70, 75) of a device eligible to run the | |||||
/// operation. | |||||
int maximum_compute_capability; | |||||
// | |||||
// Methods | |||||
// | |||||
TileDescription( | |||||
cutlass::gemm::GemmCoord threadblock_shape = | |||||
cutlass::gemm::GemmCoord(), | |||||
int threadblock_stages = 0, | |||||
cutlass::gemm::GemmCoord warp_count = cutlass::gemm::GemmCoord(), | |||||
MathInstructionDescription math_instruction = | |||||
MathInstructionDescription(), | |||||
int minimum_compute_capability = 0, | |||||
int maximum_compute_capability = 0) | |||||
: threadblock_shape(threadblock_shape), | |||||
threadblock_stages(threadblock_stages), | |||||
warp_count(warp_count), | |||||
math_instruction(math_instruction), | |||||
minimum_compute_capability(minimum_compute_capability), | |||||
maximum_compute_capability(maximum_compute_capability) {} | |||||
// Equality operator | |||||
inline bool operator==(TileDescription const& rhs) const { | |||||
return ((threadblock_shape == rhs.threadblock_shape) && | |||||
(threadblock_stages == rhs.threadblock_stages) && | |||||
(warp_count == rhs.warp_count) && | |||||
(math_instruction == rhs.math_instruction) && | |||||
(minimum_compute_capability == | |||||
rhs.minimum_compute_capability) && | |||||
(maximum_compute_capability == rhs.maximum_compute_capability)); | |||||
} | |||||
// Inequality operator | |||||
inline bool operator!=(TileDescription const& rhs) const { | |||||
return !(*this == rhs); | |||||
} | |||||
}; | |||||
/// High-level description of an operation | |||||
struct OperationDescription { | |||||
/// Unique identifier describing the operation | |||||
char const* name; | |||||
/// Operation provider | |||||
Provider provider; | |||||
/// Kind of operation | |||||
OperationKind kind; | |||||
/// Describes the tiled structure of a GEMM-like computation | |||||
TileDescription tile_description; | |||||
// | |||||
// Methods | |||||
// | |||||
OperationDescription( | |||||
char const* name = "unknown", | |||||
OperationKind kind = OperationKind::kInvalid, | |||||
TileDescription const& tile_description = TileDescription()) | |||||
: name(name), kind(kind), tile_description(tile_description) {} | |||||
}; | |||||
/// Structure describing the properties of a tensor | |||||
struct TensorDescription { | |||||
/// Numeric type of an individual element | |||||
NumericTypeID element; | |||||
/// Enumerant identifying the layout function for the tensor | |||||
LayoutTypeID layout; | |||||
/// Alignment restriction on pointers, strides, and extents | |||||
int alignment; | |||||
/// log2() of the maximum extent of each dimension | |||||
int log_extent_range; | |||||
/// log2() of the maximum value each relevant stride may have | |||||
int log_stride_range; | |||||
// | |||||
// Methods | |||||
// | |||||
TensorDescription(NumericTypeID element = NumericTypeID::kInvalid, | |||||
LayoutTypeID layout = LayoutTypeID::kInvalid, | |||||
int alignment = 1, int log_extent_range = 24, | |||||
int log_stride_range = 24) | |||||
: element(element), | |||||
layout(layout), | |||||
alignment(alignment), | |||||
log_extent_range(log_extent_range), | |||||
log_stride_range(log_stride_range) {} | |||||
}; | |||||
///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
struct GemmDescription : public OperationDescription { | |||||
GemmKind gemm_kind; | |||||
TensorDescription A; | |||||
TensorDescription B; | |||||
TensorDescription C; | |||||
int stages; | |||||
SplitKMode split_k_mode; | |||||
}; | |||||
///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
struct GemmArguments { | |||||
/// GEMM problem size | |||||
gemm::GemmCoord problem_size; | |||||
/// Device pointers to input and output matrices | |||||
void const* A; | |||||
void const* B; | |||||
void const* C; | |||||
void* D; | |||||
/// Leading dimensions of input and output matrices | |||||
int64_t lda; | |||||
int64_t ldb; | |||||
int64_t ldc; | |||||
int64_t ldd; | |||||
/// Number of partitions of K dimension | |||||
int split_k_slices; | |||||
/// Host or device pointers to epilogue scalars, note that these pointers | |||||
/// will be interpreted as ElementCompute* in method `op->run(args)`, a | |||||
/// different dtype here results in undefined epilogue behaviors | |||||
void const* alpha; | |||||
void const* beta; | |||||
}; | |||||
///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
struct ConvolutionDescription : public OperationDescription { | |||||
conv::Operator conv_op; | |||||
TensorDescription src; | |||||
TensorDescription filter; | |||||
TensorDescription dst; | |||||
TensorDescription bias; | |||||
conv::ConvType convolution_type; | |||||
ArchTagID arch_tag; | |||||
epilogue::EpilogueType epilogue_type; | |||||
int epilogue_count; | |||||
ThreadblockSwizzleID threadblock_swizzle; | |||||
bool need_load_from_const_mem; | |||||
conv::ImplicitGemmMode gemm_mode; | |||||
bool without_shared_load; | |||||
}; | |||||
///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
struct ConvolutionArguments { | |||||
/// Problem size | |||||
conv::Conv2dProblemSize problem_size; | |||||
/// Device pointers to input and output tensors | |||||
void const* src; | |||||
void const* filter; | |||||
void const* bias; | |||||
void const* z; | |||||
void* dst; | |||||
/// Host or device pointers to epilogue scalars, note that these pointers | |||||
/// will be interpreted as ElementCompute* in method `op->run(args)`, a | |||||
/// different dtype here results in undefined epilogue behaviors | |||||
void const* alpha; | |||||
void const* beta; | |||||
void const* gamma; | |||||
void const* delta; | |||||
void const* theta; | |||||
void const* threshold; | |||||
void const* scale; | |||||
/// Host pointer to extra param struct | |||||
void const* extra_param; | |||||
}; | |||||
///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
/// Base class for all operations | |||||
class Operation { | |||||
public: | |||||
virtual ~Operation() {} | |||||
virtual OperationDescription const& description() const = 0; | |||||
virtual Status run(void const* arguments, void* device_workspace = nullptr, | |||||
cudaStream_t stream = nullptr) const = 0; | |||||
}; | |||||
///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
} // namespace library | |||||
} // namespace cutlass | |||||
///////////////////////////////////////////////////////////////////////////////////////////////// |
@@ -0,0 +1,580 @@ | |||||
/*************************************************************************************************** | |||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. | |||||
* | |||||
* Redistribution and use in source and binary forms, with or without | |||||
*modification, are permitted provided that the following conditions are met: | |||||
* * Redistributions of source code must retain the above copyright notice, | |||||
*this list of conditions and the following disclaimer. | |||||
* * Redistributions in binary form must reproduce the above copyright | |||||
*notice, this list of conditions and the following disclaimer in the | |||||
*documentation and/or other materials provided with the distribution. | |||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its | |||||
*contributors may be used to endorse or promote products derived from this | |||||
*software without specific prior written permission. | |||||
* | |||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||||
*AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||||
*IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | |||||
*DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, | |||||
*INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, | |||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, | |||||
*DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY | |||||
*OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TOR (INCLUDING | |||||
*NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, | |||||
*EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||||
* | |||||
**************************************************************************************************/ | |||||
/** | |||||
* \file dnn/src/cuda/cutlass/library_internal.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#pragma once | |||||
#pragma GCC diagnostic push | |||||
#pragma GCC diagnostic ignored "-Wreorder" | |||||
#pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||||
#pragma GCC diagnostic ignored "-Wunused-parameter" | |||||
#include "cutlass/arch/arch.h" | |||||
#include "cutlass/arch/mma.h" | |||||
#include "cutlass/complex.h" | |||||
#include "cutlass/convolution/threadblock/threadblock_swizzle.h" | |||||
#include "cutlass/cutlass.h" | |||||
#include "cutlass/gemm/threadblock/threadblock_swizzle.h" | |||||
#include "cutlass/layout/matrix.h" | |||||
#include "cutlass/numeric_types.h" | |||||
#pragma GCC diagnostic pop | |||||
#include "src/cuda/cutlass/arch_mappings.h" | |||||
#include "src/cuda/cutlass/library.h" | |||||
///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
namespace cutlass { | |||||
namespace library { | |||||
///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
template <typename T> | |||||
struct NumericTypeMap; | |||||
template <> | |||||
struct NumericTypeMap<cutlass::uint1b_t> { | |||||
static NumericTypeID const kId = NumericTypeID::kB1; | |||||
}; | |||||
template <> | |||||
struct NumericTypeMap<cutlass::int4b_t> { | |||||
static NumericTypeID const kId = NumericTypeID::kS4; | |||||
}; | |||||
template <> | |||||
struct NumericTypeMap<int8_t> { | |||||
static NumericTypeID const kId = NumericTypeID::kS8; | |||||
}; | |||||
template <> | |||||
struct NumericTypeMap<int16_t> { | |||||
static NumericTypeID const kId = NumericTypeID::kS16; | |||||
}; | |||||
template <> | |||||
struct NumericTypeMap<int32_t> { | |||||
static NumericTypeID const kId = NumericTypeID::kS32; | |||||
}; | |||||
template <> | |||||
struct NumericTypeMap<int64_t> { | |||||
static NumericTypeID const kId = NumericTypeID::kS64; | |||||
}; | |||||
template <> | |||||
struct NumericTypeMap<cutlass::uint4b_t> { | |||||
static NumericTypeID const kId = NumericTypeID::kU4; | |||||
}; | |||||
template <> | |||||
struct NumericTypeMap<uint8_t> { | |||||
static NumericTypeID const kId = NumericTypeID::kU8; | |||||
}; | |||||
template <> | |||||
struct NumericTypeMap<uint16_t> { | |||||
static NumericTypeID const kId = NumericTypeID::kU16; | |||||
}; | |||||
template <> | |||||
struct NumericTypeMap<uint32_t> { | |||||
static NumericTypeID const kId = NumericTypeID::kU32; | |||||
}; | |||||
template <> | |||||
struct NumericTypeMap<uint64_t> { | |||||
static NumericTypeID const kId = NumericTypeID::kU64; | |||||
}; | |||||
template <> | |||||
struct NumericTypeMap<cutlass::half_t> { | |||||
static NumericTypeID const kId = NumericTypeID::kF16; | |||||
}; | |||||
template <> | |||||
struct NumericTypeMap<float> { | |||||
static NumericTypeID const kId = NumericTypeID::kF32; | |||||
}; | |||||
template <> | |||||
struct NumericTypeMap<double> { | |||||
static NumericTypeID const kId = NumericTypeID::kF64; | |||||
}; | |||||
template <> | |||||
struct NumericTypeMap<cutlass::complex<cutlass::half_t>> { | |||||
static NumericTypeID const kId = NumericTypeID::kCF16; | |||||
}; | |||||
template <> | |||||
struct NumericTypeMap<cutlass::complex<float>> { | |||||
static NumericTypeID const kId = NumericTypeID::kCF32; | |||||
}; | |||||
template <> | |||||
struct NumericTypeMap<cutlass::complex<double>> { | |||||
static NumericTypeID const kId = NumericTypeID::kCF64; | |||||
}; | |||||
template <> | |||||
struct NumericTypeMap<cutlass::bfloat16_t> { | |||||
static NumericTypeID const kId = NumericTypeID::kBF16; | |||||
}; | |||||
template <> | |||||
struct NumericTypeMap<cutlass::tfloat32_t> { | |||||
static NumericTypeID const kId = NumericTypeID::kTF32; | |||||
}; | |||||
///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
template <typename T> | |||||
struct MathOperationMap { | |||||
static MathOperationID const kId = MathOperationID::kInvalid; | |||||
}; | |||||
template <> | |||||
struct MathOperationMap<cutlass::arch::OpMultiplyAdd> { | |||||
static MathOperationID const kId = MathOperationID::kMultiplyAdd; | |||||
}; | |||||
template <> | |||||
struct MathOperationMap<cutlass::arch::OpMultiplyAddFastBF16> { | |||||
static MathOperationID const kId = MathOperationID::kMultiplyAddFastBF16; | |||||
}; | |||||
template <> | |||||
struct MathOperationMap<cutlass::arch::OpMultiplyAddFastF16> { | |||||
static MathOperationID const kId = MathOperationID::kMultiplyAddFastF16; | |||||
}; | |||||
template <> | |||||
struct MathOperationMap<cutlass::arch::OpMultiplyAddSaturate> { | |||||
static MathOperationID const kId = MathOperationID::kMultiplyAddSaturate; | |||||
}; | |||||
template <> | |||||
struct MathOperationMap<cutlass::arch::OpMultiplyAddComplex> { | |||||
static MathOperationID const kId = MathOperationID::kMultiplyAddComplex; | |||||
}; | |||||
template <> | |||||
struct MathOperationMap<cutlass::arch::OpMultiplyAddGaussianComplex> { | |||||
static MathOperationID const kId = | |||||
MathOperationID::kMultiplyAddGaussianComplex; | |||||
}; | |||||
template <> | |||||
struct MathOperationMap<cutlass::arch::OpXorPopc> { | |||||
static MathOperationID const kId = MathOperationID::kXorPopc; | |||||
}; | |||||
///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
template <typename T> | |||||
struct LayoutMap; | |||||
template <> | |||||
struct LayoutMap<cutlass::layout::ColumnMajor> { | |||||
static LayoutTypeID const kId = LayoutTypeID::kColumnMajor; | |||||
}; | |||||
template <> | |||||
struct LayoutMap<cutlass::layout::RowMajor> { | |||||
static LayoutTypeID const kId = LayoutTypeID::kRowMajor; | |||||
}; | |||||
template <> | |||||
struct LayoutMap<cutlass::layout::ColumnMajorInterleaved<2>> { | |||||
static LayoutTypeID const kId = LayoutTypeID::kColumnMajorInterleavedK2; | |||||
}; | |||||
template <> | |||||
struct LayoutMap<cutlass::layout::RowMajorInterleaved<2>> { | |||||
static LayoutTypeID const kId = LayoutTypeID::kRowMajorInterleavedK2; | |||||
}; | |||||
template <> | |||||
struct LayoutMap<cutlass::layout::ColumnMajorInterleaved<4>> { | |||||
static LayoutTypeID const kId = LayoutTypeID::kColumnMajorInterleavedK4; | |||||
}; | |||||
template <> | |||||
struct LayoutMap<cutlass::layout::RowMajorInterleaved<4>> { | |||||
static LayoutTypeID const kId = LayoutTypeID::kRowMajorInterleavedK4; | |||||
}; | |||||
template <> | |||||
struct LayoutMap<cutlass::layout::ColumnMajorInterleaved<16>> { | |||||
static LayoutTypeID const kId = LayoutTypeID::kColumnMajorInterleavedK16; | |||||
}; | |||||
template <> | |||||
struct LayoutMap<cutlass::layout::RowMajorInterleaved<16>> { | |||||
static LayoutTypeID const kId = LayoutTypeID::kRowMajorInterleavedK16; | |||||
}; | |||||
template <> | |||||
struct LayoutMap<cutlass::layout::ColumnMajorInterleaved<32>> { | |||||
static LayoutTypeID const kId = LayoutTypeID::kColumnMajorInterleavedK32; | |||||
}; | |||||
template <> | |||||
struct LayoutMap<cutlass::layout::RowMajorInterleaved<32>> { | |||||
static LayoutTypeID const kId = LayoutTypeID::kRowMajorInterleavedK32; | |||||
}; | |||||
template <> | |||||
struct LayoutMap<cutlass::layout::ColumnMajorInterleaved<64>> { | |||||
static LayoutTypeID const kId = LayoutTypeID::kColumnMajorInterleavedK64; | |||||
}; | |||||
template <> | |||||
struct LayoutMap<cutlass::layout::RowMajorInterleaved<64>> { | |||||
static LayoutTypeID const kId = LayoutTypeID::kRowMajorInterleavedK64; | |||||
}; | |||||
template <> | |||||
struct LayoutMap<cutlass::layout::TensorNCHW> { | |||||
static LayoutTypeID const kId = LayoutTypeID::kTensorNCHW; | |||||
}; | |||||
template <> | |||||
struct LayoutMap<cutlass::layout::TensorNHWC> { | |||||
static LayoutTypeID const kId = LayoutTypeID::kTensorNHWC; | |||||
}; | |||||
template <> | |||||
struct LayoutMap<cutlass::layout::TensorNDHWC> { | |||||
static LayoutTypeID const kId = LayoutTypeID::kTensorNDHWC; | |||||
}; | |||||
template <> | |||||
struct LayoutMap<cutlass::layout::TensorNCxHWx<4>> { | |||||
static LayoutTypeID const kId = LayoutTypeID::kTensorNC4HW4; | |||||
}; | |||||
template <> | |||||
struct LayoutMap<cutlass::layout::TensorNCxHWx<8>> { | |||||
static LayoutTypeID const kId = LayoutTypeID::kTensorNC8HW8; | |||||
}; | |||||
template <> | |||||
struct LayoutMap<cutlass::layout::TensorNCxHWx<16>> { | |||||
static LayoutTypeID const kId = LayoutTypeID::kTensorNC16HW16; | |||||
}; | |||||
template <> | |||||
struct LayoutMap<cutlass::layout::TensorNCxHWx<32>> { | |||||
static LayoutTypeID const kId = LayoutTypeID::kTensorNC32HW32; | |||||
}; | |||||
template <> | |||||
struct LayoutMap<cutlass::layout::TensorNCxHWx<64>> { | |||||
static LayoutTypeID const kId = LayoutTypeID::kTensorNC64HW64; | |||||
}; | |||||
template <> | |||||
struct LayoutMap<cutlass::layout::TensorCxRSKx<4>> { | |||||
static LayoutTypeID const kId = LayoutTypeID::kTensorC4RSK4; | |||||
}; | |||||
template <> | |||||
struct LayoutMap<cutlass::layout::TensorCxRSKx<8>> { | |||||
static LayoutTypeID const kId = LayoutTypeID::kTensorC8RSK8; | |||||
}; | |||||
template <> | |||||
struct LayoutMap<cutlass::layout::TensorCxRSKx<16>> { | |||||
static LayoutTypeID const kId = LayoutTypeID::kTensorC16RSK16; | |||||
}; | |||||
template <> | |||||
struct LayoutMap<cutlass::layout::TensorCxRSKx<32>> { | |||||
static LayoutTypeID const kId = LayoutTypeID::kTensorC32RSK32; | |||||
}; | |||||
template <> | |||||
struct LayoutMap<cutlass::layout::TensorCxRSKx<64>> { | |||||
static LayoutTypeID const kId = LayoutTypeID::kTensorC64RSK64; | |||||
}; | |||||
template <> | |||||
struct LayoutMap<cutlass::layout::TensorKxRSCx<4>> { | |||||
static LayoutTypeID const kId = LayoutTypeID::kTensorK4RSC4; | |||||
}; | |||||
///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
template <typename T> | |||||
struct OpcodeClassMap; | |||||
template <> | |||||
struct OpcodeClassMap<arch::OpClassSimt> { | |||||
static OpcodeClassID const kId = OpcodeClassID::kSimt; | |||||
}; | |||||
template <> | |||||
struct OpcodeClassMap<arch::OpClassTensorOp> { | |||||
static OpcodeClassID const kId = OpcodeClassID::kTensorOp; | |||||
}; | |||||
template <> | |||||
struct OpcodeClassMap<arch::OpClassWmmaTensorOp> { | |||||
static OpcodeClassID const kId = OpcodeClassID::kWmmaTensorOp; | |||||
}; | |||||
///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
template <typename T> | |||||
struct ArchTagMap; | |||||
template <> | |||||
struct ArchTagMap<arch::Sm50> { | |||||
static ArchTagID const kId = ArchTagID::kSm50; | |||||
}; | |||||
template <> | |||||
struct ArchTagMap<arch::Sm60> { | |||||
static ArchTagID const kId = ArchTagID::kSm60; | |||||
}; | |||||
template <> | |||||
struct ArchTagMap<arch::Sm61> { | |||||
static ArchTagID const kId = ArchTagID::kSm61; | |||||
}; | |||||
template <> | |||||
struct ArchTagMap<arch::Sm70> { | |||||
static ArchTagID const kId = ArchTagID::kSm70; | |||||
}; | |||||
template <> | |||||
struct ArchTagMap<arch::Sm72> { | |||||
static ArchTagID const kId = ArchTagID::kSm72; | |||||
}; | |||||
template <> | |||||
struct ArchTagMap<arch::Sm75> { | |||||
static ArchTagID const kId = ArchTagID::kSm75; | |||||
}; | |||||
template <> | |||||
struct ArchTagMap<arch::Sm80> { | |||||
static ArchTagID const kId = ArchTagID::kSm80; | |||||
}; | |||||
template <> | |||||
struct ArchTagMap<arch::Sm86> { | |||||
static ArchTagID const kId = ArchTagID::kSm86; | |||||
}; | |||||
///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
template <cutlass::ComplexTransform Transform> | |||||
struct ComplexTransformMap; | |||||
template <> | |||||
struct ComplexTransformMap<cutlass::ComplexTransform::kNone> { | |||||
static cutlass::library::ComplexTransform const kId = | |||||
cutlass::library::ComplexTransform::kNone; | |||||
}; | |||||
template <> | |||||
struct ComplexTransformMap<cutlass::ComplexTransform::kConjugate> { | |||||
static cutlass::library::ComplexTransform const kId = | |||||
cutlass::library::ComplexTransform::kConjugate; | |||||
}; | |||||
///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
template <cutlass::conv::Mode T> | |||||
struct ConvModeMap; | |||||
template <> | |||||
struct ConvModeMap<conv::Mode::kCrossCorrelation> { | |||||
static ConvModeID const kId = ConvModeID::kCrossCorrelation; | |||||
}; | |||||
template <> | |||||
struct ConvModeMap<conv::Mode::kConvolution> { | |||||
static ConvModeID const kId = ConvModeID::kConvolution; | |||||
}; | |||||
template <cutlass::conv::Operator T> | |||||
struct ConvKindMap; | |||||
template <> | |||||
struct ConvKindMap<conv::Operator::kFprop> { | |||||
static ConvKind const kId = ConvKind::kFprop; | |||||
}; | |||||
template <> | |||||
struct ConvKindMap<conv::Operator::kDgrad> { | |||||
static ConvKind const kId = ConvKind::kDgrad; | |||||
}; | |||||
template <> | |||||
struct ConvKindMap<conv::Operator::kWgrad> { | |||||
static ConvKind const kId = ConvKind::kWgrad; | |||||
}; | |||||
template <cutlass::conv::IteratorAlgorithm T> | |||||
struct IteratorAlgorithmMap; | |||||
template <> | |||||
struct IteratorAlgorithmMap<conv::IteratorAlgorithm::kAnalytic> { | |||||
static IteratorAlgorithmID const kId = IteratorAlgorithmID::kAnalytic; | |||||
}; | |||||
template <> | |||||
struct IteratorAlgorithmMap<conv::IteratorAlgorithm::kOptimized> { | |||||
static IteratorAlgorithmID const kId = IteratorAlgorithmID::kOptimized; | |||||
}; | |||||
///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
template <typename T> | |||||
struct ThreadblockSwizzleMap; | |||||
template <int N> | |||||
struct ThreadblockSwizzleMap< | |||||
gemm::threadblock::GemmIdentityThreadblockSwizzle<N>> { | |||||
static ThreadblockSwizzleID const kId = ThreadblockSwizzleID::kGemmIdentity; | |||||
}; | |||||
template <> | |||||
struct ThreadblockSwizzleMap< | |||||
gemm::threadblock::GemmHorizontalThreadblockSwizzle> { | |||||
static ThreadblockSwizzleID const kId = | |||||
ThreadblockSwizzleID::kGemmHorizontal; | |||||
}; | |||||
template <> | |||||
struct ThreadblockSwizzleMap< | |||||
gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle> { | |||||
static ThreadblockSwizzleID const kId = | |||||
ThreadblockSwizzleID::kGemmBatchedIdentity; | |||||
}; | |||||
template <int N> | |||||
struct ThreadblockSwizzleMap< | |||||
gemm::threadblock::GemmSplitKIdentityThreadblockSwizzle<N>> { | |||||
static ThreadblockSwizzleID const kId = | |||||
ThreadblockSwizzleID::kGemmSplitKIdentity; | |||||
}; | |||||
template <> | |||||
struct ThreadblockSwizzleMap< | |||||
gemm::threadblock::GemmSplitKHorizontalThreadblockSwizzle> { | |||||
static ThreadblockSwizzleID const kId = | |||||
ThreadblockSwizzleID::kGemmSplitKHorizontal; | |||||
}; | |||||
template <> | |||||
struct ThreadblockSwizzleMap< | |||||
gemm::threadblock::GemvBatchedStridedThreadblockDefaultSwizzle> { | |||||
static ThreadblockSwizzleID const kId = | |||||
ThreadblockSwizzleID::kGemvBatchedStridedDefault; | |||||
}; | |||||
template <> | |||||
struct ThreadblockSwizzleMap< | |||||
gemm::threadblock::GemvBatchedStridedThreadblockReductionSwizzle> { | |||||
static ThreadblockSwizzleID const kId = | |||||
ThreadblockSwizzleID::kGemvBatchedStridedReduction; | |||||
}; | |||||
template <> | |||||
struct ThreadblockSwizzleMap< | |||||
conv::threadblock::ConvolutionFpropCxRSKxThreadblockSwizzle> { | |||||
static ThreadblockSwizzleID const kId = | |||||
ThreadblockSwizzleID::kConvolutionFpropCxRSKx; | |||||
}; | |||||
template <> | |||||
struct ThreadblockSwizzleMap< | |||||
conv::threadblock::ConvolutionDgradCxRSKxThreadblockSwizzle> { | |||||
static ThreadblockSwizzleID const kId = | |||||
ThreadblockSwizzleID::kConvolutionDgradCxRSKx; | |||||
}; | |||||
template <> | |||||
struct ThreadblockSwizzleMap< | |||||
conv::threadblock::ConvolutionFpropNCxHWxThreadblockSwizzle> { | |||||
static ThreadblockSwizzleID const kId = | |||||
ThreadblockSwizzleID::kConvolutionFpropNCxHWx; | |||||
}; | |||||
template <> | |||||
struct ThreadblockSwizzleMap< | |||||
conv::threadblock::ConvolutionFpropTransThreadblockSwizzle> { | |||||
static ThreadblockSwizzleID const kId = | |||||
ThreadblockSwizzleID::kConvolutionFpropTrans; | |||||
}; | |||||
template <> | |||||
struct ThreadblockSwizzleMap< | |||||
conv::threadblock::ConvolutionDgradNCxHWxThreadblockSwizzle> { | |||||
static ThreadblockSwizzleID const kId = | |||||
ThreadblockSwizzleID::kConvolutionDgradNCxHWx; | |||||
}; | |||||
///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
template <typename Element, typename Layout> | |||||
TensorDescription make_TensorDescription(int alignment = 1) { | |||||
TensorDescription desc; | |||||
desc.element = NumericTypeMap<Element>::kId; | |||||
desc.layout = LayoutMap<Layout>::kId; | |||||
desc.alignment = alignment; | |||||
desc.log_extent_range = | |||||
int(sizeof(typename Layout::TensorCoord::Index) - 1) * 8; | |||||
desc.log_stride_range = int(sizeof(typename Layout::Stride::Index) - 1) * 8; | |||||
return desc; | |||||
} | |||||
///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
} // namespace library | |||||
} // namespace cutlass | |||||
///////////////////////////////////////////////////////////////////////////////////////////////// |
@@ -0,0 +1,96 @@ | |||||
/*************************************************************************************************** | |||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. | |||||
* | |||||
* Redistribution and use in source and binary forms, with or without | |||||
*modification, are permitted provided that the following conditions are met: | |||||
* * Redistributions of source code must retain the above copyright notice, | |||||
*this list of conditions and the following disclaimer. | |||||
* * Redistributions in binary form must reproduce the above copyright | |||||
*notice, this list of conditions and the following disclaimer in the | |||||
*documentation and/or other materials provided with the distribution. | |||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its | |||||
*contributors may be used to endorse or promote products derived from this | |||||
*software without specific prior written permission. | |||||
* | |||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||||
*AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||||
*IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | |||||
*DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, | |||||
*INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, | |||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, | |||||
*DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY | |||||
*OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TOR (INCLUDING | |||||
*NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, | |||||
*EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||||
* | |||||
**************************************************************************************************/ | |||||
/** | |||||
* \file dnn/src/cuda/cutlass/manifest.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#include <memory> | |||||
#include "src/cuda/cutlass/manifest.h" | |||||
namespace cutlass { | |||||
namespace library { | |||||
////////////////////////////////////////////////////////////////////////////////////////////////////////// | |||||
/// Top-level initialization | |||||
Status Manifest::initialize() { | |||||
if (!operations_.empty()) { | |||||
operations_.clear(); | |||||
} | |||||
// initialize procedurally generated cutlass op in manifest object | |||||
initialize_all(*this); | |||||
return Status::kSuccess; | |||||
} | |||||
/// Used for initialization | |||||
void Manifest::reserve(size_t operation_count) { | |||||
operations_.reserve(operation_count); | |||||
} | |||||
/// Graceful shutdown | |||||
Status Manifest::release() { | |||||
operations_.clear(); | |||||
return Status::kSuccess; | |||||
} | |||||
/// Appends an operation and takes ownership | |||||
void Manifest::append(Operation* operation_ptr) { | |||||
operations_.emplace_back(operation_ptr); | |||||
} | |||||
/// Returns an iterator to the first operation | |||||
OperationVector const& Manifest::operations() const { | |||||
return operations_; | |||||
} | |||||
/// Returns a const iterator | |||||
OperationVector::const_iterator Manifest::begin() const { | |||||
return operations_.begin(); | |||||
} | |||||
/// Returns a const iterator | |||||
OperationVector::const_iterator Manifest::end() const { | |||||
return operations_.end(); | |||||
} | |||||
/////////////////////////////////////////////////////////////////////////////////////////////////// | |||||
} // namespace library | |||||
} // namespace cutlass | |||||
/////////////////////////////////////////////////////////////////////////////////////////////////// |
@@ -0,0 +1,108 @@ | |||||
/*************************************************************************************************** | |||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. | |||||
* | |||||
* Redistribution and use in source and binary forms, with or without | |||||
*modification, are permitted provided that the following conditions are met: | |||||
* * Redistributions of source code must retain the above copyright notice, | |||||
*this list of conditions and the following disclaimer. | |||||
* * Redistributions in binary form must reproduce the above copyright | |||||
*notice, this list of conditions and the following disclaimer in the | |||||
*documentation and/or other materials provided with the distribution. | |||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its | |||||
*contributors may be used to endorse or promote products derived from this | |||||
*software without specific prior written permission. | |||||
* | |||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||||
*AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||||
*IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | |||||
*DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, | |||||
*INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, | |||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, | |||||
*DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY | |||||
*OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TOR (INCLUDING | |||||
*NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, | |||||
*EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||||
* | |||||
**************************************************************************************************/ | |||||
/** | |||||
* \file dnn/src/cuda/cutlass/manifest.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#pragma once | |||||
#include <list> | |||||
#include <map> | |||||
#include <memory> | |||||
#include "src/cuda/cutlass/library.h" | |||||
/////////////////////////////////////////////////////////////////////////////////////////////////// | |||||
namespace cutlass { | |||||
namespace library { | |||||
/////////////////////////////////////////////////////////////////////////////////////////////////// | |||||
// Forward declaration | |||||
class Manifest; | |||||
// init and insert all cutlass gemm operations in manifest object (procedurally | |||||
// generated using generator.py) | |||||
void initialize_all(Manifest& manifest); | |||||
///////////////////////////////////////////////////////////////////////////////////////////////////////// | |||||
/// List of operations | |||||
using OperationVector = std::vector<std::unique_ptr<Operation>>; | |||||
/////////////////////////////////////////////////////////////////////////////////////////////////// | |||||
/// Manifest of CUTLASS Library | |||||
class Manifest { | |||||
private: | |||||
/// Operation provider | |||||
Provider provider_; | |||||
/// Global list of operations | |||||
OperationVector operations_; | |||||
public: | |||||
Manifest(Provider provider = library::Provider::kCUTLASS) | |||||
: provider_(provider) {} | |||||
/// Top-level initialization | |||||
Status initialize(); | |||||
/// Used for initialization | |||||
void reserve(size_t operation_count); | |||||
/// Graceful shutdown | |||||
Status release(); | |||||
/// Appends an operation and takes ownership | |||||
void append(Operation* operation_ptr); | |||||
/// Returns an iterator to the first operation | |||||
OperationVector const& operations() const; | |||||
/// Returns a const iterator | |||||
OperationVector::const_iterator begin() const; | |||||
/// Returns a const iterator | |||||
OperationVector::const_iterator end() const; | |||||
}; | |||||
/////////////////////////////////////////////////////////////////////////////////////////////////// | |||||
} // namespace library | |||||
} // namespace cutlass | |||||
/////////////////////////////////////////////////////////////////////////////////////////////////// |
@@ -0,0 +1,179 @@ | |||||
/*************************************************************************************************** | |||||
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. | |||||
* | |||||
* Redistribution and use in source and binary forms, with or without | |||||
*modification, are permitted provided that the following conditions are met: | |||||
* * Redistributions of source code must retain the above copyright notice, | |||||
*this list of conditions and the following disclaimer. | |||||
* * Redistributions in binary form must reproduce the above copyright | |||||
*notice, this list of conditions and the following disclaimer in the | |||||
*documentation and/or other materials provided with the distribution. | |||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its | |||||
*contributors may be used to endorse or promote products derived from this | |||||
*software without specific prior written permission. | |||||
* | |||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||||
*AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||||
*IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | |||||
*DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, | |||||
*INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, | |||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, | |||||
*DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY | |||||
*OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TOR (INCLUDING | |||||
*NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, | |||||
*EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||||
* | |||||
**************************************************************************************************/ | |||||
/** | |||||
* \file dnn/src/cuda/cutlass/operation_table.cu | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#include "src/common/utils.h" | |||||
#include "src/cuda/cutlass/operation_table.h" | |||||
///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
namespace cutlass { | |||||
namespace library { | |||||
///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
GemmKey get_gemm_key_from_desc(const GemmDescription& desc) { | |||||
GemmKey key; | |||||
key.element_A = desc.A.element; | |||||
key.layout_A = desc.A.layout; | |||||
key.element_B = desc.B.element; | |||||
key.layout_B = desc.B.layout; | |||||
key.element_C = desc.C.element; | |||||
key.layout_C = desc.C.layout; | |||||
key.threadblock_shape_m = desc.tile_description.threadblock_shape.m(); | |||||
key.threadblock_shape_n = desc.tile_description.threadblock_shape.n(); | |||||
key.threadblock_shape_k = desc.tile_description.threadblock_shape.k(); | |||||
key.warp_shape_m = desc.tile_description.threadblock_shape.m() / | |||||
desc.tile_description.warp_count.m(); | |||||
key.warp_shape_n = desc.tile_description.threadblock_shape.n() / | |||||
desc.tile_description.warp_count.n(); | |||||
key.warp_shape_k = desc.tile_description.threadblock_shape.k() / | |||||
desc.tile_description.warp_count.k(); | |||||
key.instruction_shape_m = | |||||
desc.tile_description.math_instruction.instruction_shape.m(); | |||||
key.instruction_shape_n = | |||||
desc.tile_description.math_instruction.instruction_shape.n(); | |||||
key.instruction_shape_k = | |||||
desc.tile_description.math_instruction.instruction_shape.k(); | |||||
key.stages = desc.stages; | |||||
key.split_k_mode = desc.split_k_mode; | |||||
return key; | |||||
} | |||||
///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
ConvolutionKey get_convolution_key_from_desc( | |||||
const ConvolutionDescription& desc) { | |||||
ConvolutionKey key; | |||||
key.conv_op = desc.conv_op; | |||||
key.element_src = desc.src.element; | |||||
key.layout_src = desc.src.layout; | |||||
key.element_filter = desc.filter.element; | |||||
key.layout_filter = desc.filter.layout; | |||||
key.element_dst = desc.dst.element; | |||||
key.layout_dst = desc.dst.layout; | |||||
key.element_bias = desc.bias.element; | |||||
key.layout_bias = desc.bias.layout; | |||||
key.convolution_type = desc.convolution_type; | |||||
key.threadblock_shape_m = desc.tile_description.threadblock_shape.m(); | |||||
key.threadblock_shape_n = desc.tile_description.threadblock_shape.n(); | |||||
key.threadblock_shape_k = desc.tile_description.threadblock_shape.k(); | |||||
key.warp_shape_m = desc.tile_description.threadblock_shape.m() / | |||||
desc.tile_description.warp_count.m(); | |||||
key.warp_shape_n = desc.tile_description.threadblock_shape.n() / | |||||
desc.tile_description.warp_count.n(); | |||||
key.warp_shape_k = desc.tile_description.threadblock_shape.k() / | |||||
desc.tile_description.warp_count.k(); | |||||
key.instruction_shape_m = | |||||
desc.tile_description.math_instruction.instruction_shape.m(); | |||||
key.instruction_shape_n = | |||||
desc.tile_description.math_instruction.instruction_shape.n(); | |||||
key.instruction_shape_k = | |||||
desc.tile_description.math_instruction.instruction_shape.k(); | |||||
key.epilogue_type = desc.epilogue_type; | |||||
key.stages = desc.tile_description.threadblock_stages; | |||||
key.need_load_from_const_mem = desc.need_load_from_const_mem; | |||||
key.without_shared_load = desc.without_shared_load; | |||||
return key; | |||||
} | |||||
///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
void OperationTable::append(Manifest const& manifest) { | |||||
// Insert operations into appropriate data structure | |||||
for (auto const& operation : manifest) { | |||||
OperationDescription const& desc = operation->description(); | |||||
// insert all gemm operations into operation table | |||||
if (desc.kind == OperationKind::kGemm) { | |||||
GemmKey key = get_gemm_key_from_desc( | |||||
static_cast<GemmDescription const&>(desc)); | |||||
gemm_operations[key].push_back(operation.get()); | |||||
} | |||||
// insert all conv operations into operation table | |||||
if (desc.kind == OperationKind::kConvolution) { | |||||
ConvolutionKey key = get_convolution_key_from_desc( | |||||
static_cast<ConvolutionDescription const&>(desc)); | |||||
convolution_operations[key].push_back(operation.get()); | |||||
} | |||||
} | |||||
} | |||||
///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
Operation const* OperationTable::find_op(GemmKey const& key) const { | |||||
megdnn_assert(gemm_operations.count(key) > 0, | |||||
"key not found in cutlass operation table"); | |||||
auto const& ops = gemm_operations.at(key); | |||||
megdnn_assert(ops.size() == 1, "exactly one kernel expected, got %zu", | |||||
ops.size()); | |||||
return ops[0]; | |||||
} | |||||
///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
Operation const* OperationTable::find_op(ConvolutionKey const& key) const { | |||||
megdnn_assert(convolution_operations.count(key) > 0, | |||||
"key not found in cutlass operation table"); | |||||
auto const& ops = convolution_operations.at(key); | |||||
megdnn_assert(ops.size() == 1, "exactly one kernel expected, got %zu", | |||||
ops.size()); | |||||
return ops[0]; | |||||
} | |||||
///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
} // namespace library | |||||
} // namespace cutlass | |||||
///////////////////////////////////////////////////////////////////////////////////////////////// |
@@ -0,0 +1,334 @@ | |||||
/*************************************************************************************************** | |||||
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. | |||||
* | |||||
* Redistribution and use in source and binary forms, with or without | |||||
*modification, are permitted provided that the following conditions are met: | |||||
* * Redistributions of source code must retain the above copyright notice, | |||||
*this list of conditions and the following disclaimer. | |||||
* * Redistributions in binary form must reproduce the above copyright | |||||
*notice, this list of conditions and the following disclaimer in the | |||||
*documentation and/or other materials provided with the distribution. | |||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its | |||||
*contributors may be used to endorse or promote products derived from this | |||||
*software without specific prior written permission. | |||||
* | |||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||||
*AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||||
*IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | |||||
*DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, | |||||
*INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, | |||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, | |||||
*DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY | |||||
*OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TOR (INCLUDING | |||||
*NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, | |||||
*EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||||
* | |||||
**************************************************************************************************/ | |||||
/** | |||||
* \file dnn/src/cuda/cutlass/operation_table.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#pragma once | |||||
#include <unordered_map> | |||||
#include "src/common/hash_ct.h" | |||||
#include "src/cuda/cutlass/manifest.h" | |||||
#include "src/cuda/cutlass/util.h" | |||||
///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
namespace cutlass { | |||||
namespace library { | |||||
///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
class Hash { | |||||
public: | |||||
Hash() : m_val(0) {} | |||||
Hash& update(const void* ptr, size_t len) { | |||||
m_val += megdnn::XXHash64CT::hash((const char*)ptr, len, 123456); | |||||
return *this; | |||||
} | |||||
uint64_t digest() const { return m_val; } | |||||
private: | |||||
uint64_t m_val; | |||||
}; | |||||
///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
// Data Structures for GemmOperationMap | |||||
///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
struct GemmKey { | |||||
NumericTypeID element_A; | |||||
LayoutTypeID layout_A; | |||||
NumericTypeID element_B; | |||||
LayoutTypeID layout_B; | |||||
NumericTypeID element_C; | |||||
LayoutTypeID layout_C; | |||||
int threadblock_shape_m; | |||||
int threadblock_shape_n; | |||||
int threadblock_shape_k; | |||||
int warp_shape_m; | |||||
int warp_shape_n; | |||||
int warp_shape_k; | |||||
int instruction_shape_m; | |||||
int instruction_shape_n; | |||||
int instruction_shape_k; | |||||
int stages; | |||||
SplitKMode split_k_mode; | |||||
inline bool operator==(GemmKey const& rhs) const { | |||||
return (element_A == rhs.element_A) && (layout_A == rhs.layout_A) && | |||||
(element_B == rhs.element_B) && (layout_B == rhs.layout_B) && | |||||
(element_C == rhs.element_C) && (layout_C == rhs.layout_C) && | |||||
(threadblock_shape_m == rhs.threadblock_shape_m) && | |||||
(threadblock_shape_n == rhs.threadblock_shape_n) && | |||||
(threadblock_shape_k == rhs.threadblock_shape_k) && | |||||
(warp_shape_m == rhs.warp_shape_m) && | |||||
(warp_shape_n == rhs.warp_shape_n) && | |||||
(warp_shape_k == rhs.warp_shape_k) && | |||||
(instruction_shape_m == rhs.instruction_shape_m) && | |||||
(instruction_shape_n == rhs.instruction_shape_n) && | |||||
(instruction_shape_k == rhs.instruction_shape_k) && | |||||
(stages == rhs.stages) && (split_k_mode == rhs.split_k_mode); | |||||
} | |||||
inline bool operator!=(GemmKey const& rhs) const { return !(*this == rhs); } | |||||
inline std::string str() const { | |||||
auto tuple_to_str = [](int m, int n, int k) -> std::string { | |||||
return std::to_string(m) + " x " + std::to_string(n) + " x " + | |||||
std::to_string(k); | |||||
}; | |||||
std::string threadblock_shape_str = tuple_to_str( | |||||
threadblock_shape_m, threadblock_shape_n, threadblock_shape_k); | |||||
std::string warp_shape_str = | |||||
tuple_to_str(warp_shape_m, warp_shape_n, warp_shape_k); | |||||
std::string instruction_shape_str = tuple_to_str( | |||||
instruction_shape_m, instruction_shape_n, instruction_shape_k); | |||||
return std::string("{") + "\n element_A: " + to_string(element_A) + | |||||
"\n layout_A: " + to_string(layout_A) + | |||||
"\n element_B: " + to_string(element_B) + | |||||
"\n layout_B: " + to_string(layout_B) + | |||||
"\n element_C: " + to_string(element_C) + | |||||
"\n layout_C: " + to_string(layout_C) + | |||||
"\n threadblock_shape: " + threadblock_shape_str + | |||||
"\n warp_shape: " + warp_shape_str + | |||||
"\n instruction_shape: " + instruction_shape_str + | |||||
"\n stages: " + std::to_string(stages) + | |||||
"\n split_k_mode: " + to_string(split_k_mode) + "\n}"; | |||||
} | |||||
}; | |||||
struct GemmKeyHasher { | |||||
inline size_t operator()(GemmKey const& key) const { | |||||
return Hash() | |||||
.update(&key.element_A, sizeof(key.element_A)) | |||||
.update(&key.layout_A, sizeof(key.layout_A)) | |||||
.update(&key.element_B, sizeof(key.element_B)) | |||||
.update(&key.layout_B, sizeof(key.layout_B)) | |||||
.update(&key.element_C, sizeof(key.element_C)) | |||||
.update(&key.layout_C, sizeof(key.layout_C)) | |||||
.update(&key.threadblock_shape_m, | |||||
sizeof(key.threadblock_shape_m)) | |||||
.update(&key.threadblock_shape_n, | |||||
sizeof(key.threadblock_shape_n)) | |||||
.update(&key.threadblock_shape_k, | |||||
sizeof(key.threadblock_shape_k)) | |||||
.update(&key.warp_shape_m, sizeof(key.warp_shape_m)) | |||||
.update(&key.warp_shape_n, sizeof(key.warp_shape_n)) | |||||
.update(&key.warp_shape_k, sizeof(key.warp_shape_k)) | |||||
.update(&key.stages, sizeof(key.stages)) | |||||
.update(&key.split_k_mode, sizeof(key.split_k_mode)) | |||||
.digest(); | |||||
} | |||||
}; | |||||
using GemmOperationMap = | |||||
std::unordered_map<GemmKey, std::vector<Operation const*>, | |||||
GemmKeyHasher>; | |||||
///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
// Data Structures for ConvolutionOperationMap | |||||
///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
struct ConvolutionKey { | |||||
conv::Operator conv_op; | |||||
library::NumericTypeID element_src; | |||||
library::LayoutTypeID layout_src; | |||||
library::NumericTypeID element_filter; | |||||
library::LayoutTypeID layout_filter; | |||||
library::NumericTypeID element_dst; | |||||
library::LayoutTypeID layout_dst; | |||||
library::NumericTypeID element_bias; | |||||
library::LayoutTypeID layout_bias; | |||||
conv::ConvType convolution_type; | |||||
int threadblock_shape_m; | |||||
int threadblock_shape_n; | |||||
int threadblock_shape_k; | |||||
int warp_shape_m; | |||||
int warp_shape_n; | |||||
int warp_shape_k; | |||||
int instruction_shape_m; | |||||
int instruction_shape_n; | |||||
int instruction_shape_k; | |||||
epilogue::EpilogueType epilogue_type; | |||||
int stages; | |||||
bool need_load_from_const_mem; | |||||
bool without_shared_load; | |||||
inline bool operator==(ConvolutionKey const& rhs) const { | |||||
return (conv_op == rhs.conv_op) && (element_src == rhs.element_src) && | |||||
(layout_src == rhs.layout_src) && | |||||
(element_filter == rhs.element_filter) && | |||||
(layout_filter == rhs.layout_filter) && | |||||
(element_dst == rhs.element_dst) && | |||||
(layout_dst == rhs.layout_dst) && | |||||
(element_bias == rhs.element_bias) && | |||||
(layout_bias == rhs.layout_bias) && | |||||
(convolution_type == rhs.convolution_type) && | |||||
(threadblock_shape_m == rhs.threadblock_shape_m) && | |||||
(threadblock_shape_n == rhs.threadblock_shape_n) && | |||||
(threadblock_shape_k == rhs.threadblock_shape_k) && | |||||
(warp_shape_m == rhs.warp_shape_m) && | |||||
(warp_shape_n == rhs.warp_shape_n) && | |||||
(warp_shape_k == rhs.warp_shape_k) && | |||||
(instruction_shape_m == rhs.instruction_shape_m) && | |||||
(instruction_shape_n == rhs.instruction_shape_n) && | |||||
(instruction_shape_k == rhs.instruction_shape_k) && | |||||
(epilogue_type == rhs.epilogue_type) && (stages == rhs.stages) && | |||||
(need_load_from_const_mem == rhs.need_load_from_const_mem) && | |||||
(without_shared_load == rhs.without_shared_load); | |||||
} | |||||
inline bool operator!=(ConvolutionKey const& rhs) const { | |||||
return !(*this == rhs); | |||||
} | |||||
inline std::string str() const { | |||||
auto tuple_to_str = [](int m, int n, int k) -> std::string { | |||||
return std::to_string(m) + " x " + std::to_string(n) + " x " + | |||||
std::to_string(k); | |||||
}; | |||||
std::string threadblock_shape_str = tuple_to_str( | |||||
threadblock_shape_m, threadblock_shape_n, threadblock_shape_k); | |||||
std::string warp_shape_str = | |||||
tuple_to_str(warp_shape_m, warp_shape_n, warp_shape_k); | |||||
std::string instruction_shape_str = tuple_to_str( | |||||
instruction_shape_m, instruction_shape_n, instruction_shape_k); | |||||
return std::string("{") + "\n conv_op: " + to_string(conv_op) + | |||||
"\n element_src: " + to_string(element_src) + | |||||
"\n layout_src: " + to_string(layout_src) + | |||||
"\n element_filter: " + to_string(element_filter) + | |||||
"\n layout_filter: " + to_string(layout_filter) + | |||||
"\n element_dst: " + to_string(element_dst) + | |||||
"\n layout_dst: " + to_string(layout_dst) + | |||||
"\n element_bias: " + to_string(element_bias) + | |||||
"\n layout_bias: " + to_string(layout_bias) + | |||||
"\n convolution_type: " + to_string(convolution_type) + | |||||
"\n threadblock_shape: " + threadblock_shape_str + | |||||
"\n warp_shape: " + warp_shape_str + | |||||
"\n instruction_shape: " + instruction_shape_str + | |||||
"\n epilogue_type: " + to_string(epilogue_type) + | |||||
"\n stages: " + std::to_string(stages) + | |||||
"\n need_load_from_const_mem: " + | |||||
to_string(need_load_from_const_mem) + | |||||
"\n without_shared_load: " + to_string(without_shared_load) + | |||||
"\n}"; | |||||
} | |||||
}; | |||||
struct ConvolutionKeyHasher { | |||||
inline size_t operator()(ConvolutionKey const& key) const { | |||||
return Hash() | |||||
.update(&key.conv_op, sizeof(key.conv_op)) | |||||
.update(&key.conv_op, sizeof(key.conv_op)) | |||||
.update(&key.element_src, sizeof(key.element_src)) | |||||
.update(&key.layout_src, sizeof(key.layout_src)) | |||||
.update(&key.element_filter, sizeof(key.element_filter)) | |||||
.update(&key.layout_filter, sizeof(key.layout_filter)) | |||||
.update(&key.element_dst, sizeof(key.element_dst)) | |||||
.update(&key.layout_dst, sizeof(key.layout_dst)) | |||||
.update(&key.element_bias, sizeof(key.element_bias)) | |||||
.update(&key.layout_bias, sizeof(key.layout_bias)) | |||||
.update(&key.convolution_type, sizeof(key.convolution_type)) | |||||
.update(&key.threadblock_shape_m, | |||||
sizeof(key.threadblock_shape_m)) | |||||
.update(&key.threadblock_shape_n, | |||||
sizeof(key.threadblock_shape_n)) | |||||
.update(&key.threadblock_shape_k, | |||||
sizeof(key.threadblock_shape_k)) | |||||
.update(&key.warp_shape_m, sizeof(key.warp_shape_m)) | |||||
.update(&key.warp_shape_n, sizeof(key.warp_shape_n)) | |||||
.update(&key.warp_shape_k, sizeof(key.warp_shape_k)) | |||||
.update(&key.instruction_shape_m, | |||||
sizeof(key.instruction_shape_m)) | |||||
.update(&key.instruction_shape_n, | |||||
sizeof(key.instruction_shape_n)) | |||||
.update(&key.instruction_shape_k, | |||||
sizeof(key.instruction_shape_k)) | |||||
.update(&key.epilogue_type, sizeof(key.epilogue_type)) | |||||
.update(&key.stages, sizeof(key.stages)) | |||||
.update(&key.need_load_from_const_mem, | |||||
sizeof(key.need_load_from_const_mem)) | |||||
.update(&key.without_shared_load, | |||||
sizeof(key.without_shared_load)) | |||||
.digest(); | |||||
} | |||||
}; | |||||
using ConvolutionOperationMap = | |||||
std::unordered_map<ConvolutionKey, std::vector<Operation const*>, | |||||
ConvolutionKeyHasher>; | |||||
///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
/// Table of cutlass::library::Operation instances | |||||
class OperationTable { | |||||
public: | |||||
/// Map of all operations of type kGemm | |||||
GemmOperationMap gemm_operations; | |||||
/// Map of all operations of type kConvolution | |||||
ConvolutionOperationMap convolution_operations; | |||||
public: | |||||
void append(Manifest const& manifest); | |||||
Operation const* find_op(GemmKey const& key) const; | |||||
Operation const* find_op(ConvolutionKey const& key) const; | |||||
}; | |||||
///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
} // namespace library | |||||
} // namespace cutlass | |||||
///////////////////////////////////////////////////////////////////////////////////////////////// |
@@ -0,0 +1,72 @@ | |||||
/*************************************************************************************************** | |||||
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. | |||||
* | |||||
* Redistribution and use in source and binary forms, with or without | |||||
*modification, are permitted provided that the following conditions are met: | |||||
* * Redistributions of source code must retain the above copyright notice, | |||||
*this list of conditions and the following disclaimer. | |||||
* * Redistributions in binary form must reproduce the above copyright | |||||
*notice, this list of conditions and the following disclaimer in the | |||||
*documentation and/or other materials provided with the distribution. | |||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its | |||||
*contributors may be used to endorse or promote products derived from this | |||||
*software without specific prior written permission. | |||||
* | |||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||||
*AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||||
*IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | |||||
*DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, | |||||
*INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, | |||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, | |||||
*DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY | |||||
*OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TOR (INCLUDING | |||||
*NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, | |||||
*EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||||
* | |||||
**************************************************************************************************/ | |||||
/** | |||||
* \file dnn/src/cuda/cutlass/singleton.cu | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#include <memory> | |||||
#include "src/cuda/cutlass/singleton.h" | |||||
///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
namespace cutlass { | |||||
namespace library { | |||||
///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
static std::unique_ptr<Singleton> instance; | |||||
///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
Singleton::Singleton() { | |||||
manifest.initialize(); | |||||
operation_table.append(manifest); | |||||
} | |||||
Singleton const& Singleton::get() { | |||||
if (!instance.get()) { | |||||
instance.reset(new Singleton); | |||||
} | |||||
return *instance.get(); | |||||
} | |||||
///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
} // namespace library | |||||
} // namespace cutlass | |||||
///////////////////////////////////////////////////////////////////////////////////////////////// |
@@ -0,0 +1,70 @@ | |||||
/*************************************************************************************************** | |||||
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. | |||||
* | |||||
* Redistribution and use in source and binary forms, with or without | |||||
*modification, are permitted provided that the following conditions are met: | |||||
* * Redistributions of source code must retain the above copyright notice, | |||||
*this list of conditions and the following disclaimer. | |||||
* * Redistributions in binary form must reproduce the above copyright | |||||
*notice, this list of conditions and the following disclaimer in the | |||||
*documentation and/or other materials provided with the distribution. | |||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its | |||||
*contributors may be used to endorse or promote products derived from this | |||||
*software without specific prior written permission. | |||||
* | |||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||||
*AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||||
*IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | |||||
*DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, | |||||
*INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, | |||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, | |||||
*DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY | |||||
*OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TOR (INCLUDING | |||||
*NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, | |||||
*EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||||
* | |||||
**************************************************************************************************/ | |||||
/** | |||||
* \file dnn/src/cuda/cutlass/singleton.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#pragma once | |||||
#include "src/cuda/cutlass/operation_table.h" | |||||
///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
namespace cutlass { | |||||
namespace library { | |||||
///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
/// Singleton instance stores a Manifest and Operation table | |||||
class Singleton { | |||||
public: | |||||
/// Manifest object | |||||
Manifest manifest; | |||||
/// Operation table referencing the Manifest | |||||
OperationTable operation_table; | |||||
public: | |||||
Singleton(); | |||||
static Singleton const& get(); | |||||
}; | |||||
///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
} // namespace library | |||||
} // namespace cutlass | |||||
///////////////////////////////////////////////////////////////////////////////////////////////// |
@@ -0,0 +1,218 @@ | |||||
/*************************************************************************************************** | |||||
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. | |||||
* | |||||
* Redistribution and use in source and binary forms, with or without | |||||
*modification, are permitted provided that the following conditions are met: | |||||
* * Redistributions of source code must retain the above copyright notice, | |||||
*this list of conditions and the following disclaimer. | |||||
* * Redistributions in binary form must reproduce the above copyright | |||||
*notice, this list of conditions and the following disclaimer in the | |||||
*documentation and/or other materials provided with the distribution. | |||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its | |||||
*contributors may be used to endorse or promote products derived from this | |||||
*software without specific prior written permission. | |||||
* | |||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||||
*AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||||
*IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | |||||
*DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, | |||||
*INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, | |||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, | |||||
*DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY | |||||
*OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TOR (INCLUDING | |||||
*NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, | |||||
*EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||||
* | |||||
**************************************************************************************************/ | |||||
/** | |||||
* \file dnn/src/cuda/cutlass/util.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#pragma once | |||||
#include "src/cuda/cutlass/library.h" | |||||
///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
namespace cutlass { | |||||
namespace library { | |||||
///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
/// Lexical cast from string | |||||
template <typename T> | |||||
T from_string(std::string const&); | |||||
/// Converts a Provider enumerant to a string | |||||
char const* to_string(Provider provider, bool pretty = false); | |||||
/// Parses a Provider enumerant from a string | |||||
template <> | |||||
Provider from_string<Provider>(std::string const& str); | |||||
/// Converts a GemmKind enumerant to a string | |||||
char const* to_string(GemmKind type, bool pretty = false); | |||||
/// Converts a NumericType enumerant to a string | |||||
char const* to_string(OperationKind type, bool pretty = false); | |||||
/// Parses a NumericType enumerant from a string | |||||
template <> | |||||
OperationKind from_string<OperationKind>(std::string const& str); | |||||
/// Converts a NumericType enumerant to a string | |||||
char const* to_string(NumericTypeID type, bool pretty = false); | |||||
/// Parses a NumericType enumerant from a string | |||||
template <> | |||||
NumericTypeID from_string<NumericTypeID>(std::string const& str); | |||||
/// Returns the size of a data type in bits | |||||
int sizeof_bits(NumericTypeID type); | |||||
/// Returns true if the numeric type is a complex data type or false if | |||||
/// real-valued. | |||||
bool is_complex_type(NumericTypeID type); | |||||
/// Returns the real-valued type underlying a type (only different from 'type' | |||||
/// if complex) | |||||
NumericTypeID get_real_type(NumericTypeID type); | |||||
/// Returns true if numeric type is integer | |||||
bool is_integer_type(NumericTypeID type); | |||||
/// Returns true if numeric type is signed | |||||
bool is_signed_type(NumericTypeID type); | |||||
/// Returns true if numeric type is a signed integer | |||||
bool is_signed_integer(NumericTypeID type); | |||||
/// returns true if numeric type is an unsigned integer | |||||
bool is_unsigned_integer(NumericTypeID type); | |||||
/// Returns true if numeric type is floating-point type | |||||
bool is_float_type(NumericTypeID type); | |||||
/// To string method for cutlass::Status | |||||
char const* to_string(Status status, bool pretty = false); | |||||
/// Converts a LayoutTypeID enumerant to a string | |||||
char const* to_string(LayoutTypeID layout, bool pretty = false); | |||||
/// Parses a LayoutType enumerant from a string | |||||
template <> | |||||
LayoutTypeID from_string<LayoutTypeID>(std::string const& str); | |||||
/// Returns the rank of a layout's stride base on the LayoutTypeID | |||||
int get_layout_stride_rank(LayoutTypeID layout_id); | |||||
/// Converts a OpcodeClassID enumerant to a string | |||||
char const* to_string(OpcodeClassID type, bool pretty = false); | |||||
/// Converts a OpcodeClassID enumerant from a string | |||||
template <> | |||||
OpcodeClassID from_string<OpcodeClassID>(std::string const& str); | |||||
/// Converts a ComplexTransform enumerant to a string | |||||
char const* to_string(ComplexTransform type, bool pretty = false); | |||||
/// Converts a ComplexTransform enumerant from a string | |||||
template <> | |||||
ComplexTransform from_string<ComplexTransform>(std::string const& str); | |||||
/// Converts a SplitKMode enumerant to a string | |||||
char const* to_string(SplitKMode split_k_mode, bool pretty = false); | |||||
/// Converts a SplitKMode enumerant from a string | |||||
template <> | |||||
SplitKMode from_string<SplitKMode>(std::string const& str); | |||||
/// Converts a ConvModeID enumerant to a string | |||||
char const* to_string(ConvModeID type, bool pretty = false); | |||||
/// Converts a ConvModeID enumerant from a string | |||||
template <> | |||||
ConvModeID from_string<ConvModeID>(std::string const& str); | |||||
/// Converts a IteratorAlgorithmID enumerant to a string | |||||
char const* to_string(IteratorAlgorithmID type, bool pretty = false); | |||||
/// Converts a IteratorAlgorithmID enumerant from a string | |||||
template <> | |||||
IteratorAlgorithmID from_string<IteratorAlgorithmID>(std::string const& str); | |||||
/// Converts a ConvKind enumerant to a string | |||||
char const* to_string(ConvKind type, bool pretty = false); | |||||
/// Converts a ConvKind enumerant from a string | |||||
template <> | |||||
ConvKind from_string<ConvKind>(std::string const& str); | |||||
/// Lexical cast from int64_t to string | |||||
std::string lexical_cast(int64_t int_value); | |||||
/// Lexical cast a string to a byte array. Returns true if cast is successful or | |||||
/// false if invalid. | |||||
bool lexical_cast(std::vector<uint8_t>& bytes, NumericTypeID type, | |||||
std::string const& str); | |||||
/// Lexical cast TO a string FROM a byte array. Returns true if cast is | |||||
/// successful or false if invalid. | |||||
std::string lexical_cast(std::vector<uint8_t>& bytes, NumericTypeID type); | |||||
/// Casts from a signed int64 to the destination type. Returns true if | |||||
/// successful. | |||||
bool cast_from_int64(std::vector<uint8_t>& bytes, NumericTypeID type, | |||||
int64_t src); | |||||
/// Casts from an unsigned int64 to the destination type. Returns true if | |||||
/// successful. | |||||
bool cast_from_uint64(std::vector<uint8_t>& bytes, NumericTypeID type, | |||||
uint64_t src); | |||||
/// Casts from a real value represented as a double to the destination type. | |||||
/// Returns true if successful. | |||||
bool cast_from_double(std::vector<uint8_t>& bytes, NumericTypeID type, | |||||
double src); | |||||
///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
/// Converts a conv::Operator enumerant to a string | |||||
char const* to_string(conv::Operator conv_op, bool pretty = false); | |||||
/// Converts a ConvType enumerant to a string | |||||
char const* to_string(conv::ConvType type, bool pretty = false); | |||||
/// Converts an ArchTagID enumerant to a string | |||||
char const* to_string(ArchTagID tag, bool pretty = false); | |||||
/// Converts an EpilogueType enumerant to a string | |||||
char const* to_string(epilogue::EpilogueType type, bool pretty = false); | |||||
/// Converts a ThreadblockSwizzleID enumerant to a string | |||||
char const* to_string(ThreadblockSwizzleID threadblock_swizzle, | |||||
bool pretty = false); | |||||
/// Converts a bool value to a string | |||||
char const* to_string(bool val, bool pretty = false); | |||||
/// Converts a MathOperationID enumerant to a string | |||||
char const* to_string(MathOperationID math_op, bool pretty = false); | |||||
/// Converts an ImplicitGemmMode enumerant to a string | |||||
char const* to_string(conv::ImplicitGemmMode mode, bool pretty = false); | |||||
///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
} // namespace library | |||||
} // namespace cutlass | |||||
///////////////////////////////////////////////////////////////////////////////////////////////// |
@@ -10,15 +10,14 @@ | |||||
* implied. | * implied. | ||||
*/ | */ | ||||
#include "src/cuda/cutlass/singleton.h" | |||||
#include "src/cuda/handle.h" | #include "src/cuda/handle.h" | ||||
#include "src/cuda/matrix_mul/algos.h" | #include "src/cuda/matrix_mul/algos.h" | ||||
#include "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuh" | |||||
#include "src/cuda/utils.h" | #include "src/cuda/utils.h" | ||||
#if CUDA_VERSION >= 9020 | #if CUDA_VERSION >= 9020 | ||||
using namespace megdnn; | using namespace megdnn; | ||||
using namespace cuda; | using namespace cuda; | ||||
using namespace cutlass_wrapper; | |||||
bool MatrixMulForwardImpl::AlgoFloat32SIMT::is_available( | bool MatrixMulForwardImpl::AlgoFloat32SIMT::is_available( | ||||
const SizeArgs& args) const { | const SizeArgs& args) const { | ||||
@@ -44,25 +43,62 @@ size_t MatrixMulForwardImpl::AlgoFloat32SIMT::get_workspace_in_bytes( | |||||
} | } | ||||
void MatrixMulForwardImpl::AlgoFloat32SIMT::exec(const ExecArgs& args) const { | void MatrixMulForwardImpl::AlgoFloat32SIMT::exec(const ExecArgs& args) const { | ||||
size_t lda = args.tensor_a.layout.stride[0], | |||||
ldb = args.tensor_b.layout.stride[0], | |||||
ldc = args.tensor_c.layout.stride[0]; | |||||
int64_t lda = args.tensor_a.layout.stride[0], | |||||
ldb = args.tensor_b.layout.stride[0], | |||||
ldc = args.tensor_c.layout.stride[0]; | |||||
auto&& param = args.opr->param(); | auto&& param = args.opr->param(); | ||||
int m = args.tensor_c.layout.shape[0], n = args.tensor_c.layout.shape[1], | int m = args.tensor_c.layout.shape[0], n = args.tensor_c.layout.shape[1], | ||||
k = args.tensor_a.layout.shape[param.transposeA ? 0 : 1]; | k = args.tensor_a.layout.shape[param.transposeA ? 0 : 1]; | ||||
GemmCoord problem_size{m, n, k}; | |||||
cutlass::gemm::GemmCoord problem_size{m, n, k}; | |||||
auto&& stream = cuda_stream(args.opr->handle()); | auto&& stream = cuda_stream(args.opr->handle()); | ||||
int* workspace = reinterpret_cast<int*>(args.workspace.raw_ptr); | int* workspace = reinterpret_cast<int*>(args.workspace.raw_ptr); | ||||
return cutlass_matrix_mul_float32_simt( | |||||
args.tensor_a.ptr<dt_float32>(), param.transposeA, lda, | |||||
args.tensor_b.ptr<dt_float32>(), param.transposeB, ldb, | |||||
args.tensor_c.ptr<dt_float32>(), ldc, workspace, problem_size, 1.f, | |||||
0.f, | |||||
GemmCoord{m_algo_param.threadblock_m, m_algo_param.threadblock_n, | |||||
m_algo_param.threadblock_k}, | |||||
GemmCoord{m_algo_param.warp_m, m_algo_param.warp_n, | |||||
m_algo_param.warp_k}, | |||||
stream); | |||||
// \note these constants of cutlass epilogue will be passed to struct | |||||
// `GemmArguments` by pointer and interpreted as ElementCompute*, a | |||||
// different dtype here results in undefined epilogue behaviors | |||||
float alpha = 1.f, beta = 0.f; | |||||
using namespace cutlass::library; | |||||
auto layoutA = param.transposeA ? LayoutTypeID::kColumnMajor | |||||
: LayoutTypeID::kRowMajor; | |||||
auto layoutB = param.transposeB ? LayoutTypeID::kColumnMajor | |||||
: LayoutTypeID::kRowMajor; | |||||
GemmKey key{NumericTypeID::kF32, | |||||
layoutA, | |||||
NumericTypeID::kF32, | |||||
layoutB, | |||||
NumericTypeID::kF32, | |||||
LayoutTypeID::kRowMajor, | |||||
m_algo_param.threadblock_m, | |||||
m_algo_param.threadblock_n, | |||||
m_algo_param.threadblock_k, | |||||
m_algo_param.warp_m, | |||||
m_algo_param.warp_n, | |||||
m_algo_param.warp_k, | |||||
1, | |||||
1, | |||||
1, | |||||
2, | |||||
SplitKMode::kNone}; | |||||
const Operation* op = Singleton::get().operation_table.find_op(key); | |||||
GemmArguments gemm_args{problem_size, | |||||
args.tensor_a.raw_ptr, | |||||
args.tensor_b.raw_ptr, | |||||
args.tensor_c.raw_ptr, | |||||
args.tensor_c.raw_ptr, | |||||
lda, | |||||
ldb, | |||||
ldc, | |||||
ldc, | |||||
1, | |||||
&alpha, | |||||
&beta}; | |||||
cutlass_check(op->run(&gemm_args, workspace, stream)); | |||||
} | } | ||||
#endif | #endif | ||||
@@ -10,15 +10,14 @@ | |||||
* implied. | * implied. | ||||
*/ | */ | ||||
#include "src/cuda/cutlass/singleton.h" | |||||
#include "src/cuda/handle.h" | #include "src/cuda/handle.h" | ||||
#include "src/cuda/matrix_mul/algos.h" | #include "src/cuda/matrix_mul/algos.h" | ||||
#include "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuh" | |||||
#include "src/cuda/utils.h" | #include "src/cuda/utils.h" | ||||
#if CUDA_VERSION >= 9020 | #if CUDA_VERSION >= 9020 | ||||
using namespace megdnn; | using namespace megdnn; | ||||
using namespace cuda; | using namespace cuda; | ||||
using namespace cutlass_wrapper; | |||||
bool MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::is_available( | bool MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::is_available( | ||||
const SizeArgs& args) const { | const SizeArgs& args) const { | ||||
@@ -50,26 +49,63 @@ size_t MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::get_workspace_in_bytes( | |||||
void MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::exec( | void MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::exec( | ||||
const ExecArgs& args) const { | const ExecArgs& args) const { | ||||
size_t lda = args.tensor_a.layout.stride[0], | |||||
ldb = args.tensor_b.layout.stride[0], | |||||
ldc = args.tensor_c.layout.stride[0]; | |||||
int64_t lda = args.tensor_a.layout.stride[0], | |||||
ldb = args.tensor_b.layout.stride[0], | |||||
ldc = args.tensor_c.layout.stride[0]; | |||||
auto&& param = args.opr->param(); | auto&& param = args.opr->param(); | ||||
int m = args.tensor_c.layout.shape[0], n = args.tensor_c.layout.shape[1], | int m = args.tensor_c.layout.shape[0], n = args.tensor_c.layout.shape[1], | ||||
k = args.tensor_a.layout.shape[param.transposeA ? 0 : 1]; | k = args.tensor_a.layout.shape[param.transposeA ? 0 : 1]; | ||||
GemmCoord problem_size{m, n, k}; | |||||
cutlass::gemm::GemmCoord problem_size{m, n, k}; | |||||
int split_k_slices = std::max(1, k / n); | int split_k_slices = std::max(1, k / n); | ||||
auto&& stream = cuda_stream(args.opr->handle()); | auto&& stream = cuda_stream(args.opr->handle()); | ||||
int* workspace = reinterpret_cast<int*>(args.workspace.raw_ptr); | int* workspace = reinterpret_cast<int*>(args.workspace.raw_ptr); | ||||
return cutlass_matrix_mul_float32_simt( | |||||
args.tensor_a.ptr<dt_float32>(), param.transposeA, lda, | |||||
args.tensor_b.ptr<dt_float32>(), param.transposeB, ldb, | |||||
args.tensor_c.ptr<dt_float32>(), ldc, workspace, problem_size, 1.f, | |||||
0.f, | |||||
GemmCoord{m_algo_param.threadblock_m, m_algo_param.threadblock_n, | |||||
m_algo_param.threadblock_k}, | |||||
GemmCoord{m_algo_param.warp_m, m_algo_param.warp_n, | |||||
m_algo_param.warp_k}, | |||||
stream, split_k_slices); | |||||
// \note these constants of cutlass epilogue will be passed to struct | |||||
// `GemmArguments` by pointer and interpreted as ElementCompute*, a | |||||
// different dtype here results in undefined epilogue behaviors | |||||
float alpha = 1.f, beta = 0.f; | |||||
using namespace cutlass::library; | |||||
auto layoutA = param.transposeA ? LayoutTypeID::kColumnMajor | |||||
: LayoutTypeID::kRowMajor; | |||||
auto layoutB = param.transposeB ? LayoutTypeID::kColumnMajor | |||||
: LayoutTypeID::kRowMajor; | |||||
GemmKey key{NumericTypeID::kF32, | |||||
layoutA, | |||||
NumericTypeID::kF32, | |||||
layoutB, | |||||
NumericTypeID::kF32, | |||||
LayoutTypeID::kRowMajor, | |||||
m_algo_param.threadblock_m, | |||||
m_algo_param.threadblock_n, | |||||
m_algo_param.threadblock_k, | |||||
m_algo_param.warp_m, | |||||
m_algo_param.warp_n, | |||||
m_algo_param.warp_k, | |||||
1, | |||||
1, | |||||
1, | |||||
2, | |||||
SplitKMode::kParallel}; | |||||
Operation const* op = Singleton::get().operation_table.find_op(key); | |||||
GemmArguments gemm_args{problem_size, | |||||
args.tensor_a.raw_ptr, | |||||
args.tensor_b.raw_ptr, | |||||
args.tensor_c.raw_ptr, | |||||
args.tensor_c.raw_ptr, | |||||
lda, | |||||
ldb, | |||||
ldc, | |||||
ldc, | |||||
split_k_slices, | |||||
&alpha, | |||||
&beta}; | |||||
cutlass_check(op->run(&gemm_args, workspace, stream)); | |||||
} | } | ||||
#endif | #endif | ||||
@@ -1,157 +0,0 @@ | |||||
/** | |||||
* \file dnn/src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cu | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
// ignore warning of cutlass | |||||
#include "cuda.h" | |||||
#if __CUDACC_VER_MAJOR__ > 9 || \ | |||||
(__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2) | |||||
#pragma GCC diagnostic push | |||||
#pragma GCC diagnostic ignored "-Wunused-parameter" | |||||
#pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||||
#include "cutlass/gemm/device/gemm.h" | |||||
#include "cutlass/gemm/device/gemm_splitk_parallel.h" | |||||
#include "cutlass/gemm/kernel/default_gemv.h" | |||||
#include "src/common/opr_param_defs_enumv.cuh" | |||||
#include "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuh" | |||||
#pragma GCC diagnostic pop | |||||
using namespace megdnn; | |||||
using namespace cuda; | |||||
using namespace cutlass_wrapper; | |||||
/* ================= cutlass kernel wrapper for f32 matrix mul ================ | |||||
*/ | |||||
#define DISPATCH(cb) \ | |||||
cb(64, 256, 8, 32, 64, 8); \ | |||||
cb(256, 64, 8, 64, 32, 8); \ | |||||
cb(32, 256, 8, 16, 64, 8); \ | |||||
cb(256, 32, 8, 64, 16, 8); \ | |||||
cb(128, 128, 8, 32, 64, 8); \ | |||||
cb(128, 64, 8, 64, 32, 8); \ | |||||
cb(64, 128, 8, 32, 64, 8); \ | |||||
cb(128, 32, 8, 64, 32, 8); \ | |||||
cb(32, 128, 8, 32, 64, 8); \ | |||||
cb(64, 64, 8, 32, 64, 8); \ | |||||
cb(32, 64, 8, 32, 64, 8); \ | |||||
cb(64, 32, 8, 64, 32, 8); \ | |||||
cb(32, 32, 8, 32, 32, 8); \ | |||||
cb(8, 32, 8, 8, 32, 8); \ | |||||
cb(16, 32, 8, 16, 32, 8); \ | |||||
cb(16, 64, 8, 16, 64, 8); \ | |||||
cb(16, 128, 8, 16, 64, 8); \ | |||||
megdnn_assert(false, \ | |||||
"unsupported threadblock shape (%dx%dx%d) and warp shape " \ | |||||
"(%dx%dx%d)", \ | |||||
threadblock_shape.m(), threadblock_shape.n(), \ | |||||
threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ | |||||
warp_shape.k()); | |||||
void megdnn::cuda::cutlass_wrapper::cutlass_matrix_mul_float32_simt( | |||||
const float* d_A, bool transpose_A, size_t lda, const float* d_B, | |||||
bool transpose_B, size_t ldb, float* d_C, size_t ldc, int* workspace, | |||||
GemmCoord const& problem_size, float alpha, float beta, | |||||
const GemmCoord& threadblock_shape, const GemmCoord& warp_shape, | |||||
cudaStream_t stream, int split_k_slices) { | |||||
static constexpr int kEpilogueElementsPerAccess = 1; | |||||
using EpilogueOp = cutlass::epilogue::thread::LinearCombination< | |||||
float, kEpilogueElementsPerAccess, float, float>; | |||||
typename EpilogueOp::Params epilogue{alpha, beta}; | |||||
if (split_k_slices == 1) { | |||||
#define cb(threadblock_m_, threadblock_n_, threadblock_k_, warp_m_, warp_n_, \ | |||||
warp_k_) \ | |||||
if (threadblock_shape.m() == threadblock_m_ && \ | |||||
threadblock_shape.n() == threadblock_n_ && \ | |||||
threadblock_shape.k() == threadblock_k_ && \ | |||||
warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ | |||||
warp_shape.k() == warp_k_) { \ | |||||
using ThreadBlockShape = \ | |||||
cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \ | |||||
threadblock_k_>; \ | |||||
using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \ | |||||
using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; \ | |||||
using Gemm = cutlass::gemm::device::Gemm< \ | |||||
float, LayoutA, float, LayoutB, float, \ | |||||
cutlass::layout::RowMajor, float, cutlass::arch::OpClassSimt, \ | |||||
cutlass::arch::Sm50, ThreadBlockShape, WarpShape, \ | |||||
InstructionShape, EpilogueOp, \ | |||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, \ | |||||
2>; \ | |||||
return cutlass_matrix_mul_wrapper<Gemm>(d_A, lda, d_B, ldb, d_C, ldc, \ | |||||
workspace, problem_size, \ | |||||
epilogue, stream); \ | |||||
} | |||||
if (!transpose_A && !transpose_B) { | |||||
using LayoutA = cutlass::layout::RowMajor; | |||||
using LayoutB = cutlass::layout::RowMajor; | |||||
DISPATCH(cb) | |||||
} else if (!transpose_A && transpose_B) { | |||||
using LayoutA = cutlass::layout::RowMajor; | |||||
using LayoutB = cutlass::layout::ColumnMajor; | |||||
DISPATCH(cb) | |||||
} else if (transpose_A && !transpose_B) { | |||||
using LayoutA = cutlass::layout::ColumnMajor; | |||||
using LayoutB = cutlass::layout::RowMajor; | |||||
DISPATCH(cb) | |||||
} else { | |||||
megdnn_assert(transpose_A && transpose_B); | |||||
using LayoutA = cutlass::layout::ColumnMajor; | |||||
using LayoutB = cutlass::layout::ColumnMajor; | |||||
DISPATCH(cb) | |||||
} | |||||
#undef cb | |||||
} else { | |||||
#define cb(threadblock_m_, threadblock_n_, threadblock_k_, warp_m_, warp_n_, \ | |||||
warp_k_) \ | |||||
if (threadblock_shape.m() == threadblock_m_ && \ | |||||
threadblock_shape.n() == threadblock_n_ && \ | |||||
threadblock_shape.k() == threadblock_k_ && \ | |||||
warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ | |||||
warp_shape.k() == warp_k_) { \ | |||||
using ThreadBlockShape = \ | |||||
cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \ | |||||
threadblock_k_>; \ | |||||
using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \ | |||||
using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; \ | |||||
using Gemm = cutlass::gemm::device::GemmSplitKParallel< \ | |||||
float, LayoutA, float, LayoutB, float, \ | |||||
cutlass::layout::RowMajor, float, cutlass::arch::OpClassSimt, \ | |||||
cutlass::arch::Sm50, ThreadBlockShape, WarpShape, \ | |||||
InstructionShape, EpilogueOp>; \ | |||||
return cutlass_matrix_mul_wrapper<Gemm>( \ | |||||
d_A, lda, d_B, ldb, d_C, ldc, workspace, problem_size, \ | |||||
epilogue, stream, split_k_slices); \ | |||||
} | |||||
if (!transpose_A && !transpose_B) { | |||||
using LayoutA = cutlass::layout::RowMajor; | |||||
using LayoutB = cutlass::layout::RowMajor; | |||||
DISPATCH(cb) | |||||
} else if (!transpose_A && transpose_B) { | |||||
using LayoutA = cutlass::layout::RowMajor; | |||||
using LayoutB = cutlass::layout::ColumnMajor; | |||||
DISPATCH(cb) | |||||
} else if (transpose_A && !transpose_B) { | |||||
using LayoutA = cutlass::layout::ColumnMajor; | |||||
using LayoutB = cutlass::layout::RowMajor; | |||||
DISPATCH(cb) | |||||
} else { | |||||
megdnn_assert(transpose_A && transpose_B); | |||||
using LayoutA = cutlass::layout::ColumnMajor; | |||||
using LayoutB = cutlass::layout::ColumnMajor; | |||||
DISPATCH(cb) | |||||
} | |||||
#undef cb | |||||
} | |||||
} | |||||
#undef DISPATCH | |||||
#endif | |||||
// vim: syntax=cuda.doxygen |
@@ -21,22 +21,6 @@ namespace cutlass_wrapper { | |||||
using GemmCoord = cutlass::gemm::GemmCoord; | using GemmCoord = cutlass::gemm::GemmCoord; | ||||
using BatchedGemmCoord = cutlass::gemm::BatchedGemmCoord; | using BatchedGemmCoord = cutlass::gemm::BatchedGemmCoord; | ||||
template <typename Gemm> | |||||
void cutlass_matrix_mul_wrapper( | |||||
const typename Gemm::ElementA* d_A, size_t lda, | |||||
const typename Gemm::ElementB* d_B, size_t ldb, | |||||
typename Gemm::ElementC* d_C, size_t ldc, int* workspace, | |||||
GemmCoord const& problem_size, | |||||
typename Gemm::EpilogueOutputOp::Params const& epilogue, | |||||
cudaStream_t stream, int split_k_slices = 1); | |||||
void cutlass_matrix_mul_float32_simt( | |||||
const float* d_A, bool transpose_A, size_t lda, const float* d_B, | |||||
bool transpose_B, size_t ldb, float* d_C, size_t ldc, int* workspace, | |||||
GemmCoord const& problem_size, float alpha, float beta, | |||||
const GemmCoord& threadblock_shape, const GemmCoord& warp_shape, | |||||
cudaStream_t stream, int split_k_slices = 1); | |||||
template <typename GemvKernel> | template <typename GemvKernel> | ||||
void cutlass_vector_matrix_mul_batched_strided_wrapper( | void cutlass_vector_matrix_mul_batched_strided_wrapper( | ||||
BatchedGemmCoord const& problem_size, | BatchedGemmCoord const& problem_size, | ||||
@@ -1,57 +0,0 @@ | |||||
/** | |||||
* \file | |||||
* dnn/src/cuda/matrix_mul/matrix_mul_float_simt_cutlass_wrapper.cuinl | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#include "cutlass/gemm/device/gemm.h" | |||||
#include "cutlass/gemm/device/gemm_splitk_parallel.h" | |||||
#include "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuh" | |||||
using namespace megdnn; | |||||
using namespace cuda; | |||||
using namespace cutlass_wrapper; | |||||
template <typename Gemm> | |||||
void megdnn::cuda::cutlass_wrapper::cutlass_matrix_mul_wrapper( | |||||
const typename Gemm::ElementA* d_A, size_t lda, | |||||
const typename Gemm::ElementB* d_B, size_t ldb, | |||||
typename Gemm::ElementC* d_C, size_t ldc, int* workspace, | |||||
GemmCoord const& problem_size, | |||||
typename Gemm::EpilogueOutputOp::Params const& epilogue, | |||||
cudaStream_t stream, int split_k_slices) { | |||||
using TensorRefA = cutlass::TensorRef<typename Gemm::ElementA const, | |||||
typename Gemm::LayoutA>; | |||||
using TensorRefB = cutlass::TensorRef<typename Gemm::ElementB const, | |||||
typename Gemm::LayoutB>; | |||||
using TensorRefC = cutlass::TensorRef<typename Gemm::ElementC const, | |||||
typename Gemm::LayoutC>; | |||||
using TensorRefD = | |||||
cutlass::TensorRef<typename Gemm::ElementC, typename Gemm::LayoutC>; | |||||
TensorRefA tensor_a{const_cast<typename Gemm::ElementA*>(d_A), | |||||
typename Gemm::LayoutA{static_cast<int>(lda)}}; | |||||
TensorRefB tensor_b{const_cast<typename Gemm::ElementB*>(d_B), | |||||
typename Gemm::LayoutB{static_cast<int>(ldb)}}; | |||||
TensorRefC tensor_c{nullptr, typename Gemm::LayoutC{static_cast<int>(ldc)}}; | |||||
TensorRefD tensor_d{d_C, typename Gemm::LayoutC{static_cast<int>(ldc)}}; | |||||
typename Gemm::Arguments arguments{problem_size, | |||||
tensor_a, | |||||
tensor_b, | |||||
tensor_c, | |||||
tensor_d.non_const_ref(), | |||||
epilogue, | |||||
split_k_slices}; | |||||
Gemm gemm_op; | |||||
cutlass_check(gemm_op.initialize(arguments, workspace)); | |||||
cutlass_check(gemm_op(stream)); | |||||
after_kernel_launch(); | |||||
} | |||||
// vim: syntax=cuda.doxygen |