/**
 * \file dnn/src/cuda/conv_bias/matmul_8x8x32.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "src/common/conv_bias.h"
#include "src/cuda/conv_bias/algo.h"
#include "src/cuda/conv_bias/matmul/im2col_nhwc_int8.cuh"
#include "src/cuda/utils.cuh"
#include "src/cuda/utils.h"

using namespace megdnn;
using namespace cuda;

bool ConvBiasForwardImpl::AlgoMatmul8x8x32::is_available(const SizeArgs& args) const {
    if (args.z_layout->ndim > 0)
        return false;
    if (!is_compute_capability_required(6, 1))
        return false;
    if (args.dst_layout->dtype.enumv() == DTypeEnum::Quantized4Asymm ||
        args.dst_layout->dtype.enumv() == DTypeEnum::QuantizedS4) {
        return false;
    }

    auto dst_layout = *args.dst_layout;
    if (dst_layout.dtype.enumv() != args.bias_layout->dtype.enumv()) {
        dst_layout.dtype = DType();
        args.opr->check_or_deduce_dtype_fwd(
                args.src_layout->dtype, args.filter_layout->dtype, dst_layout.dtype);
    }
    using NonlineMode = param::ConvBias::NonlineMode;
    auto&& fm = args.filter_meta;
    bool available =
            (args.nonlinear_mode == NonlineMode::IDENTITY ||
             args.nonlinear_mode == NonlineMode::RELU) &&
            ((args.src_layout->dtype == dtype::Int8() &&
              dst_layout.dtype == dtype::Int32() &&
              fm.dtype.enumv() == DTypeEnum::Int8) ||
             (args.src_layout->dtype.enumv() == DTypeEnum::QuantizedS8 &&
              dst_layout.dtype.enumv() == DTypeEnum::QuantizedS32)) &&
            fm.group == 1 && fm.spatial_ndim == 2 &&
            (fm.format == Param::Format::NHWC || fm.format == Param::Format::NCHW4);
    return available;
}

template <param::ConvBias::Format format>
WorkspaceBundle ConvBiasForwardImpl::AlgoMatmul8x8x32::get_bundle(
        const SizeArgs& args) const {
    size_t src_unroll_part, filter_reshape_part;
    size_t relayout_src_part = 0, relayout_filter_part = 0, relayout_dst_part = 0;
    auto&& fm = args.filter_meta;
    size_t n, ih, iw, oh, ow, fh, fw, ic, oc;
    n = args.dst_layout->shape[0];
    fh = fm.spatial[0];
    fw = fm.spatial[1];
    if (format == Param::Format::NHWC) {
        oh = args.dst_layout->shape[1];
        ow = args.dst_layout->shape[2];
        ic = args.src_layout->shape[3];
        oc = args.dst_layout->shape[3];
    } else {  // NCHW4
        ic = args.src_layout->shape[1] * 4;
        ih = args.src_layout->shape[2];
        iw = args.src_layout->shape[3];
        oc = args.dst_layout->shape[1] * 4;
        oh = args.dst_layout->shape[2];
        ow = args.dst_layout->shape[3];
        relayout_src_part = n * ic * ih * iw * sizeof(int8_t);
        relayout_filter_part = ic * oc * fh * fw * sizeof(int8_t);
        relayout_dst_part = n * oc * oh * ow * sizeof(int32_t);
    }
    // ``ld'' is short for ``leading dimension''; round FH*FW*IC up to a
    // multiple of 4
    size_t ld = (fh * fw * ic + 3) & ~3;
    if (need_src_unroll(args)) {
        src_unroll_part = n * oh * ow * ld * sizeof(int8_t);
    } else {
        src_unroll_part = 0;
    }
    if (need_filter_reshape(args)) {
        filter_reshape_part = oc * ld * sizeof(int8_t);
    } else {
        filter_reshape_part = 0;
    }
    SmallVector<size_t> sizes = {
            src_unroll_part, filter_reshape_part, relayout_src_part,
            relayout_filter_part, relayout_dst_part};

    auto dst_layout = *args.dst_layout;
    if (dst_layout.dtype.enumv() != args.bias_layout->dtype.enumv()) {
        dst_layout.dtype = DType();
        args.opr->check_or_deduce_dtype_fwd(
                args.src_layout->dtype, args.filter_layout->dtype, dst_layout.dtype);
        sizes.push_back(dst_layout.span().dist_byte());
    }
    return WorkspaceBundle(nullptr, sizes);
}
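// Illustrative sketch (an editorial addition, not referenced by the algorithm):
// the ``ld`` computed in get_bundle above rounds FH*FW*IC up to the next
// multiple of 4 via `(x + 3) & ~3`, so that the int8 operands later handed to
// cublasGemmEx get a 4-aligned leading dimension. The hypothetical helper below
// only demonstrates that bit trick.
namespace {
constexpr size_t round_up_to_multiple_of_4(size_t x) {
    // adding 3 and clearing the two low bits rounds x up to a multiple of 4
    return (x + 3) & ~size_t(3);
}
static_assert(round_up_to_multiple_of_4(5) == 8, "5 rounds up to 8");
static_assert(round_up_to_multiple_of_4(8) == 8, "8 is already a multiple of 4");
}  // namespace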
size_t ConvBiasForwardImpl::AlgoMatmul8x8x32::get_workspace_in_bytes(
        const SizeArgs& args) const {
    if (args.filter_meta.format == Param::Format::NHWC) {
        auto bundle = get_bundle<Param::Format::NHWC>(args);
        return bundle.total_size_in_bytes();
    } else {  // NCHW4
        auto bundle = get_bundle<Param::Format::NCHW4>(args);
        return bundle.total_size_in_bytes();
    }
}

template <param::ConvBias::Format format>
void ConvBiasForwardImpl::AlgoMatmul8x8x32::exec_internal(const ExecArgs& args) const {
    auto stream = args.handle->stream();
    auto cublas_handle = args.handle->cublas_handle();
    auto alpha = args.handle->one_device_i32();
    auto beta = args.handle->zero_device_i32();
    auto&& fm = args.filter_meta;
    auto bundle = get_bundle<format>(args);
    bundle.set(args.workspace.raw_ptr);

    TensorND src_tensor = *args.src_tensor;
    TensorND dst_tensor = *args.dst_tensor;
    TensorND filter_tensor = *args.filter_tensor;
    if (format == Param::Format::NCHW4) {
        // NCHW4: relayout src/filter into NHWC workspace buffers; the NHWC
        // result is relayouted back to the NCHW4 dst after the GEMM.
        auto to_nhwc = [](const TensorLayout& layout, void* raw_ptr) -> TensorND {
            return {raw_ptr,
                    {{layout[0], layout[2], layout[3], layout[1] * 4}, layout.dtype}};
        };
        src_tensor = to_nhwc(*args.src_layout, bundle.get(2));
        filter_tensor = to_nhwc(args.filter_tensor->layout, bundle.get(3));
        dst_tensor = to_nhwc(*args.dst_layout, bundle.get(4));

        auto relayout = [&](const TensorND& src, void* dst_ptr) {
            auto N = src.layout[0], C = src.layout[1] * 4, H = src.layout[2],
                 W = src.layout[3];
            args.handle->relayout_opr()->exec(
                    {src.raw_ptr(),
                     TensorLayout{
                             {N, H, W, C / 4, 4},
                             {src.layout.stride[0], src.layout.stride[2],
                              src.layout.stride[3], src.layout.stride[1],
                              src.layout.stride[4]},
                             src.layout.dtype}},
                    {dst_ptr, TensorLayout{{N, H, W, C / 4, 4}, src.layout.dtype}});
        };
        relayout(*args.src_tensor, src_tensor.raw_ptr());
        relayout(*args.filter_tensor, filter_tensor.raw_ptr());
    }

    size_t N, IH, IW, IC;
    N = src_tensor.layout.shape[0];
    IH = src_tensor.layout.shape[1];
    IW = src_tensor.layout.shape[2];
    IC = src_tensor.layout.shape[3];
    auto IWS = src_tensor.layout.stride[2];
    auto FH = fm.spatial[0], FW = fm.spatial[1];
    auto OH = dst_tensor.layout.shape[1], OW = dst_tensor.layout.shape[2],
         OC = dst_tensor.layout.shape[3];
    auto OWS = dst_tensor.layout.stride[2];
    auto PH = fm.padding[0], PW = fm.padding[1];
    auto SH = fm.stride[0], SW = fm.stride[1];
    auto DH = fm.dilation[0], DW = fm.dilation[1];
    auto LD = (FH * FW * IC + 3) & ~3;
    int8_t *inp0 = nullptr, *inp1 = nullptr;
    ptrdiff_t inp0_stride = 0, inp1_stride = 0;
    if (need_src_unroll(args)) {
        inp0 = static_cast<int8_t*>(bundle.get(0));
        inp0_stride = LD;
        im2col_nhwc_int8(
                src_tensor.compatible_ptr<dt_int8>(), inp0, N, IH, IW, IC, IWS, OH, OW,
                OC, OWS, FH, FW, PH, PW, SH, SW, DH, DW, LD, fm.should_flip, stream);
    } else {
        inp0 = src_tensor.compatible_ptr<dt_int8>();
        inp0_stride = IWS;
    }
    if (need_filter_reshape(args)) {
        // copy (OC, FH*FW*IC) to (OC, FH*FW*IC) with stride=LD
        inp1 = static_cast<int8_t*>(bundle.get(1));
        cuda_check(cudaMemcpy2DAsync(
                inp1, LD * sizeof(int8_t), filter_tensor.raw_ptr(),
                FH * FW * IC * sizeof(int8_t), FH * FW * IC * sizeof(int8_t), OC,
                cudaMemcpyDeviceToDevice, stream));
        inp1_stride = LD;
    } else {
        inp1 = filter_tensor.compatible_ptr<dt_int8>();
        inp1_stride = FH * FW * IC;
    }
    cublas_check(cublasGemmEx(
            cublas_handle, CUBLAS_OP_T, CUBLAS_OP_N, OC, N * OH * OW, FH * FW * IC,
            alpha, inp1, CUDA_R_8I, inp1_stride, inp0, CUDA_R_8I, inp0_stride, beta,
            dst_tensor.compatible_ptr<dt_int32>(), CUDA_R_32I, OWS, CUDA_R_32I,
            CUBLAS_GEMM_DFALT));
    if (format == Param::Format::NCHW4) {
        // relayout the NHWC result in the workspace back to the NCHW4 dst
        args.handle->relayout_opr()->exec(
                {dst_tensor.compatible_ptr<dt_int32>(),
                 TensorLayout{
                         {N, OC / 4, OH, OW, 4},
                         {static_cast<ptrdiff_t>(OH * OW * OC), 4,
                          static_cast<ptrdiff_t>(OC * OW), static_cast<ptrdiff_t>(OC),
                          1},
                         dst_tensor.layout.dtype}},
                *args.dst_tensor);
    }
}
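// Illustrative host-side reference (an editorial sketch, not used by this
// algorithm): exec_internal above realizes the convolution as im2col followed
// by cublasGemmEx, i.e.
//   dst[n, oh, ow, oc] =
//       sum_{fh, fw, ic} src[n, oh*SH - PH + fh*DH, ow*SW - PW + fw*DW, ic]
//                        * filter[oc, fh, fw, ic]
// accumulated in int32, with each unrolled-src row padded to
// LD = round_up(FH*FW*IC, 4). The hypothetical helper below spells out the same
// contraction naively on the CPU (zero padding, no filter flipping assumed).
namespace {
inline void naive_nhwc_int8_conv_reference(
        const int8_t* src, const int8_t* filter, int32_t* dst, size_t N, size_t IH,
        size_t IW, size_t IC, size_t OH, size_t OW, size_t OC, size_t FH, size_t FW,
        size_t PH, size_t PW, size_t SH, size_t SW, size_t DH, size_t DW) {
    // src: NHWC int8; filter: (OC, FH, FW, IC) int8, i.e. the same row-major
    // (OC, FH*FW*IC) matrix handed to cublasGemmEx; dst: NHWC int32
    for (size_t n = 0; n < N; ++n)
        for (size_t oh = 0; oh < OH; ++oh)
            for (size_t ow = 0; ow < OW; ++ow)
                for (size_t oc = 0; oc < OC; ++oc) {
                    int32_t acc = 0;
                    for (size_t fh = 0; fh < FH; ++fh)
                        for (size_t fw = 0; fw < FW; ++fw) {
                            ptrdiff_t ih = static_cast<ptrdiff_t>(oh * SH + fh * DH) -
                                           static_cast<ptrdiff_t>(PH);
                            ptrdiff_t iw = static_cast<ptrdiff_t>(ow * SW + fw * DW) -
                                           static_cast<ptrdiff_t>(PW);
                            if (ih < 0 || iw < 0 ||
                                ih >= static_cast<ptrdiff_t>(IH) ||
                                iw >= static_cast<ptrdiff_t>(IW))
                                continue;  // out-of-bounds taps read as zero
                            for (size_t ic = 0; ic < IC; ++ic) {
                                acc += static_cast<int32_t>(
                                               src[((n * IH + ih) * IW + iw) * IC +
                                                   ic]) *
                                       static_cast<int32_t>(
                                               filter[((oc * FH + fh) * FW + fw) * IC +
                                                      ic]);
                            }
                        }
                    dst[((n * OH + oh) * OW + ow) * OC + oc] = acc;
                }
}
}  // namespace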
void ConvBiasForwardImpl::AlgoMatmul8x8x32::exec(const ExecArgs& args) const {
    ExecArgs conv_args = args;
    TensorND conv_dst_tensor = *args.dst_tensor;
    if (args.filter_meta.format == Param::Format::NHWC) {
        auto bundle = get_bundle<Param::Format::NHWC>(args);
        bundle.set(args.workspace.raw_ptr);
        if (args.dst_layout->dtype.enumv() != args.bias_layout->dtype.enumv()) {
            conv_dst_tensor = TensorND{
                    bundle.get(bundle.nr_workspace() - 1), args.dst_tensor->layout};
            conv_dst_tensor.layout.dtype = DType();
            args.opr->check_or_deduce_dtype_fwd(
                    args.src_layout->dtype, args.filter_layout->dtype,
                    conv_dst_tensor.layout.dtype);
        }
        conv_args.dst_tensor = &conv_dst_tensor;
        conv_args.dst_layout = &conv_dst_tensor.layout;
    } else {  // NCHW4
        auto bundle = get_bundle<Param::Format::NCHW4>(args);
        bundle.set(args.workspace.raw_ptr);
        if (args.dst_layout->dtype.enumv() != args.bias_layout->dtype.enumv()) {
            conv_dst_tensor = TensorND{
                    bundle.get(bundle.nr_workspace() - 1), args.dst_tensor->layout};
            conv_dst_tensor.layout.dtype = DType();
            args.opr->check_or_deduce_dtype_fwd(
                    args.src_layout->dtype, args.filter_layout->dtype,
                    conv_dst_tensor.layout.dtype);
        }
        conv_args.dst_tensor = &conv_dst_tensor;
        conv_args.dst_layout = &conv_dst_tensor.layout;
    }
    if (args.filter_meta.format == Param::Format::NHWC) {
        exec_internal<Param::Format::NHWC>(conv_args);
    } else {  // NCHW4
        exec_internal<Param::Format::NCHW4>(conv_args);
    }
    handle_bias_and_nonlinear(
            args.handle, args.nonlinear_mode, &conv_dst_tensor, args.dst_tensor,
            args.bias_tensor);
}

bool ConvBiasForwardImpl::AlgoMatmul8x8x32::need_filter_reshape(
        const SizeArgs& args) const {
    // cublasGemmEx requires the leading dimension of the filter matrix to be a
    // multiple of 4.
    auto&& fm = args.filter_meta;
    size_t ic;
    if (args.filter_meta.format == Param::Format::NHWC) {
        ic = args.src_layout->shape[3];
    } else {  // NCHW4
        ic = args.src_layout->shape[1] * 4;
    }
    return !(ic * fm.spatial[0] * fm.spatial[1] % 4 == 0);
}

bool ConvBiasForwardImpl::AlgoMatmul8x8x32::need_src_unroll(
        const SizeArgs& args) const {
    // cublasGemmEx requires the leading dimension of the unrolled src to be a
    // multiple of 4.
    size_t stride;
    if (args.filter_meta.format == Param::Format::NHWC) {
        stride = args.src_layout->stride[2];
    } else {  // NCHW4
        stride = args.src_layout->shape[1] * 4;
    }
    auto&& fm = args.filter_meta;
    return !(
            fm.spatial[0] == 1 && fm.spatial[1] == 1 && fm.stride[0] == 1 &&
            fm.stride[1] == 1 && fm.padding[0] == 0 && fm.padding[1] == 0 &&
            stride % 4 == 0);
}

// vim: syntax=cpp.doxygen