|
- /**
- * \file dnn/src/cuda/conv_bias/matmul_8x8x32.cpp
- * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
- *
- * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- */
- #include "src/common/conv_bias.h"
- #include "src/cuda/conv_bias/algo.h"
- #include "src/cuda/conv_bias/matmul/im2col_nhwc_int8.cuh"
- #include "src/cuda/utils.cuh"
- #include "src/cuda/utils.h"
-
- using namespace megdnn;
- using namespace cuda;
-
- // Report whether this algorithm accepts the problem described by `args`.
- //
- // Rejected up front: a non-empty z layout, a failing
- // is_compute_capability_required(6, 1) check, and Quantized4Asymm / QuantizedS4
- // dst dtypes. Beyond that the problem must use IDENTITY or RELU, pair
- // Int8 src/filter with Int32 dst (or QuantizedS8 src with QuantizedS32 dst),
- // have group == 1 and a 2D filter, and be laid out as NHWC or NCHW4.
- bool ConvBiasForwardImpl::AlgoMatmul8x8x32::is_available(const SizeArgs& args) const {
-     if (args.z_layout->ndim > 0 || !is_compute_capability_required(6, 1))
-         return false;
-
-     const auto dst_enum = args.dst_layout->dtype.enumv();
-     if (dst_enum == DTypeEnum::Quantized4Asymm || dst_enum == DTypeEnum::QuantizedS4)
-         return false;
-
-     // When dst and bias dtypes disagree, the dtype checked below is the one
-     // deduced from src and filter rather than the caller-supplied dst dtype.
-     DType conv_dst_dtype = args.dst_layout->dtype;
-     if (dst_enum != args.bias_layout->dtype.enumv()) {
-         conv_dst_dtype = DType();
-         args.opr->check_or_deduce_dtype_fwd(
-                 args.src_layout->dtype, args.filter_layout->dtype, conv_dst_dtype);
-     }
-
-     using NonlineMode = param::ConvBias::NonlineMode;
-     if (args.nonlinear_mode != NonlineMode::IDENTITY &&
-         args.nonlinear_mode != NonlineMode::RELU)
-         return false;
-
-     auto&& fm = args.filter_meta;
-     const bool plain_int8 = args.src_layout->dtype == dtype::Int8() &&
-                             conv_dst_dtype == dtype::Int32() &&
-                             fm.dtype.enumv() == DTypeEnum::Int8;
-     const bool quantized_s8 =
-             args.src_layout->dtype.enumv() == DTypeEnum::QuantizedS8 &&
-             conv_dst_dtype.enumv() == DTypeEnum::QuantizedS32;
-     if (!plain_int8 && !quantized_s8)
-         return false;
-
-     if (fm.group != 1 || fm.spatial_ndim != 2)
-         return false;
-
-     return fm.format == Param::Format::NHWC || fm.format == Param::Format::NCHW4;
- }
-
- // Build the (not yet bound, base pointer is nullptr) workspace description.
- //
- // Chunk order, as consumed through bundle.get(i) elsewhere in this file:
- //   0    im2col-unrolled src, 0 bytes unless need_src_unroll(args)
- //   1    filter rows padded to ld, 0 bytes unless need_filter_reshape(args)
- //   2-4  NHWC copies of src / filter / dst, 0 bytes for the NHWC format
- //   5    present only when dst and bias dtypes differ: a dst-shaped buffer in
- //        the dtype deduced from src and filter
- template <param::ConvBias::Format format>
- WorkspaceBundle ConvBiasForwardImpl::AlgoMatmul8x8x32::get_bundle(
-         const SizeArgs& args) const {
-     auto&& fm = args.filter_meta;
-     const auto& src = *args.src_layout;
-     const auto& dst = *args.dst_layout;
-
-     const size_t batch = dst.shape[0];
-     const size_t fh = fm.spatial[0];
-     const size_t fw = fm.spatial[1];
-
-     size_t in_ch, out_ch, out_h, out_w;
-     size_t nhwc_src_bytes = 0, nhwc_filter_bytes = 0, nhwc_dst_bytes = 0;
-     if (format == Param::Format::NHWC) {
-         out_h = dst.shape[1];
-         out_w = dst.shape[2];
-         in_ch = src.shape[3];
-         out_ch = dst.shape[3];
-     } else {
-         // NCHW4: axis 1 counts groups of four channels
-         in_ch = src.shape[1] * 4;
-         out_ch = dst.shape[1] * 4;
-         out_h = dst.shape[2];
-         out_w = dst.shape[3];
-         const size_t in_h = src.shape[2];
-         const size_t in_w = src.shape[3];
-
-         // room for NHWC copies: int8 src and filter, int32 dst
-         nhwc_src_bytes = batch * in_ch * in_h * in_w * sizeof(int8_t);
-         nhwc_filter_bytes = in_ch * out_ch * fh * fw * sizeof(int8_t);
-         nhwc_dst_bytes = batch * out_ch * out_h * out_w * sizeof(int32_t);
-     }
-
-     // ld is short for ``leading dimension'': fh * fw * ic rounded up to a
-     // multiple of 4
-     const size_t ld = (fh * fw * in_ch + 3) & ~3;
-     const size_t unrolled_src_bytes =
-             need_src_unroll(args) ? batch * out_h * out_w * ld * sizeof(int8_t) : 0;
-     const size_t padded_filter_bytes =
-             need_filter_reshape(args) ? out_ch * ld * sizeof(int8_t) : 0;
-
-     SmallVector<size_t> sizes = {
-             unrolled_src_bytes, padded_filter_bytes, nhwc_src_bytes,
-             nhwc_filter_bytes, nhwc_dst_bytes};
-
-     if (dst.dtype.enumv() != args.bias_layout->dtype.enumv()) {
-         // extra trailing chunk sized for dst in the deduced dtype
-         auto deduced_dst = dst;
-         deduced_dst.dtype = DType();
-         args.opr->check_or_deduce_dtype_fwd(
-                 args.src_layout->dtype, args.filter_layout->dtype,
-                 deduced_dst.dtype);
-         sizes.push_back(deduced_dst.span().dist_byte());
-     }
-
-     return WorkspaceBundle(nullptr, sizes);
- }
-
- // Total workspace size: the sum reported by the bundle built for the filter's
- // format (NHWC, otherwise treated as NCHW4).
- size_t ConvBiasForwardImpl::AlgoMatmul8x8x32::get_workspace_in_bytes(
-         const SizeArgs& args) const {
-     if (args.filter_meta.format != Param::Format::NHWC) {
-         // NCHW4
-         return get_bundle<Param::Format::NCHW4>(args).total_size_in_bytes();
-     }
-     return get_bundle<Param::Format::NHWC>(args).total_size_in_bytes();
- }
-
- // Run the convolution itself (no bias, no nonlinearity) as one cublasGemmEx
- // call with CUDA_R_8I inputs and CUDA_R_32I output, writing to args.dst_tensor.
- //
- // Steps, all issued through args.handle:
- //   1. NCHW4 only: copy src and filter into NHWC workspace chunks 2 and 3 and
- //      point dst at chunk 4.
- //   2. If need_src_unroll(args): im2col the src into chunk 0 with row stride LD.
- //   3. If need_filter_reshape(args): copy the filter into chunk 1 so that each
- //      of the OC rows has stride LD.
- //   4. GEMM with m = OC, n = N * OH * OW, k = FH * FW * IC.
- //   5. NCHW4 only: relayout the NHWC result from chunk 4 into args.dst_tensor.
- template <param::ConvBias::Format format>
- void ConvBiasForwardImpl::AlgoMatmul8x8x32::exec_internal(const ExecArgs& args) const {
-     auto stream = args.handle->stream();
-     auto cublas_handle = args.handle->cublas_handle();
-     auto alpha = args.handle->one_device_i32();
-     auto beta = args.handle->zero_device_i32();
-     auto&& fm = args.filter_meta;
-
-     auto bundle = get_bundle<format>(args);
-     bundle.set(args.workspace.raw_ptr);
-
-     // Operands as the GEMM sees them; for NCHW4 they are redirected to
-     // workspace-backed NHWC tensors below.
-     TensorND src_nhwc = *args.src_tensor;
-     TensorND dst_nhwc = *args.dst_tensor;
-     TensorND filter_nhwc = *args.filter_tensor;
-     if (format == Param::Format::NCHW4) {
-         // Describe a (d0, d1, d2, d3, 4) layout as a (d0, d2, d3, d1 * 4)
-         // tensor living at `ptr`.
-         auto as_nhwc = [](const TensorLayout& nchw4, void* ptr) -> TensorND {
-             return {ptr,
-                     {{nchw4[0], nchw4[2], nchw4[3], nchw4[1] * 4}, nchw4.dtype}};
-         };
-         src_nhwc = as_nhwc(*args.src_layout, bundle.get(2));
-         filter_nhwc = as_nhwc(args.filter_tensor->layout, bundle.get(3));
-         dst_nhwc = as_nhwc(*args.dst_layout, bundle.get(4));
-
-         // Copy `from` into `to`, reading it through a permuted-stride view so
-         // that the destination axis order is (N, H, W, C / 4, 4).
-         auto copy_as_nhwc = [&](const TensorND& from, void* to) {
-             auto&& ly = from.layout;
-             auto batch = ly[0], channels = ly[1] * 4, height = ly[2],
-                  width = ly[3];
-             TensorLayout permuted_view{
-                     {batch, height, width, channels / 4, 4},
-                     {ly.stride[0], ly.stride[2], ly.stride[3], ly.stride[1],
-                      ly.stride[4]},
-                     ly.dtype};
-             TensorLayout target{{batch, height, width, channels / 4, 4}, ly.dtype};
-             args.handle->relayout_opr()->exec(
-                     {from.raw_ptr(), permuted_view}, {to, target});
-         };
-         copy_as_nhwc(*args.src_tensor, src_nhwc.raw_ptr());
-         copy_as_nhwc(*args.filter_tensor, filter_nhwc.raw_ptr());
-     }
-
-     const size_t N = src_nhwc.layout.shape[0];
-     const size_t IH = src_nhwc.layout.shape[1];
-     const size_t IW = src_nhwc.layout.shape[2];
-     const size_t IC = src_nhwc.layout.shape[3];
-     auto IWS = src_nhwc.layout.stride[2];
-
-     auto OH = dst_nhwc.layout.shape[1], OW = dst_nhwc.layout.shape[2],
-          OC = dst_nhwc.layout.shape[3];
-     auto OWS = dst_nhwc.layout.stride[2];
-
-     auto FH = fm.spatial[0], FW = fm.spatial[1];
-     auto PH = fm.padding[0], PW = fm.padding[1];
-     auto SH = fm.stride[0], SW = fm.stride[1];
-     auto DH = fm.dilation[0], DW = fm.dilation[1];
-
-     // GEMM reduction length and its round-up to a multiple of 4
-     auto K = FH * FW * IC;
-     auto LD = (K + 3) & ~3;
-
-     // src-side GEMM operand: the tensor itself, or its im2col expansion
-     int8_t* src_mat = nullptr;
-     ptrdiff_t src_mat_ld = 0;
-     if (!need_src_unroll(args)) {
-         src_mat = src_nhwc.compatible_ptr<dt_int8>();
-         src_mat_ld = IWS;
-     } else {
-         src_mat = static_cast<int8_t*>(bundle.get(0));
-         src_mat_ld = LD;
-         im2col_nhwc_int8(
-                 src_nhwc.compatible_ptr<dt_int8>(), src_mat, N, IH, IW, IC, IWS, OH,
-                 OW, OC, OWS, FH, FW, PH, PW, SH, SW, DH, DW, LD, fm.should_flip,
-                 stream);
-     }
-
-     // filter-side GEMM operand: the tensor itself, or a row-padded copy
-     int8_t* filter_mat = nullptr;
-     ptrdiff_t filter_mat_ld = 0;
-     if (!need_filter_reshape(args)) {
-         filter_mat = filter_nhwc.compatible_ptr<dt_int8>();
-         filter_mat_ld = K;
-     } else {
-         // copy (OC, FH*FW*IC) to (OC, FH*FW*IC) with stride=LD
-         filter_mat = static_cast<int8_t*>(bundle.get(1));
-         cuda_check(cudaMemcpy2DAsync(
-                 filter_mat, LD * sizeof(int8_t), filter_nhwc.raw_ptr(),
-                 K * sizeof(int8_t), K * sizeof(int8_t), OC,
-                 cudaMemcpyDeviceToDevice, stream));
-         filter_mat_ld = LD;
-     }
-
-     cublas_check(cublasGemmEx(
-             cublas_handle, CUBLAS_OP_T, CUBLAS_OP_N, OC, N * OH * OW, K, alpha,
-             filter_mat, CUDA_R_8I, filter_mat_ld, src_mat, CUDA_R_8I, src_mat_ld,
-             beta, dst_nhwc.compatible_ptr<dt_int32>(), CUDA_R_32I, OWS, CUDA_R_32I,
-             CUBLAS_GEMM_DFALT));
-
-     if (format == Param::Format::NCHW4) {
-         // Read the NHWC result through (N, OC / 4, OH, OW, 4) strides and
-         // relayout it into the caller's dst tensor.
-         TensorLayout result_view{
-                 {N, OC / 4, OH, OW, 4},
-                 {static_cast<ptrdiff_t>(OH * OW * OC), 4,
-                  static_cast<ptrdiff_t>(OC * OW), static_cast<ptrdiff_t>(OC), 1},
-                 dst_nhwc.layout.dtype};
-         args.handle->relayout_opr()->exec(
-                 {dst_nhwc.compatible_ptr<int32_t>(), result_view},
-                 *args.dst_tensor);
-     }
- }
-
- // Entry point: run the matmul convolution, then pass its result together with
- // the bias to handle_bias_and_nonlinear, which receives args.dst_tensor as the
- // destination.
- //
- // When dst and bias dtypes differ, the convolution writes to the last workspace
- // chunk using the dtype deduced from src and filter; otherwise it writes
- // straight into args.dst_tensor.
- void ConvBiasForwardImpl::AlgoMatmul8x8x32::exec(const ExecArgs& args) const {
-     const bool is_nhwc = args.filter_meta.format == Param::Format::NHWC;
-
-     // Both formats follow the same steps; only the bundle layout differs.
-     auto bundle = is_nhwc ? get_bundle<Param::Format::NHWC>(args)
-                           : get_bundle<Param::Format::NCHW4>(args);
-     bundle.set(args.workspace.raw_ptr);
-
-     TensorND conv_dst_tensor = *args.dst_tensor;
-     const bool dst_differs_from_bias =
-             args.dst_layout->dtype.enumv() != args.bias_layout->dtype.enumv();
-     if (dst_differs_from_bias) {
-         conv_dst_tensor =
-                 TensorND{bundle.get(bundle.nr_workspace() - 1),
-                          args.dst_tensor->layout};
-         conv_dst_tensor.layout.dtype = DType();
-         args.opr->check_or_deduce_dtype_fwd(
-                 args.src_layout->dtype, args.filter_layout->dtype,
-                 conv_dst_tensor.layout.dtype);
-     }
-
-     // Same arguments, but with dst redirected to the convolution output.
-     ExecArgs conv_args = args;
-     conv_args.dst_tensor = &conv_dst_tensor;
-     conv_args.dst_layout = &conv_dst_tensor.layout;
-
-     if (is_nhwc) {
-         exec_internal<Param::Format::NHWC>(conv_args);
-     } else {
-         // NCHW4
-         exec_internal<Param::Format::NCHW4>(conv_args);
-     }
-
-     handle_bias_and_nonlinear(
-             args.handle, args.nonlinear_mode, &conv_dst_tensor, args.dst_tensor,
-             args.bias_tensor);
- }
-
- // cublasGemmEx requires the stride of the filter matrix to be multiples of 4,
- // so the filter must be copied to a padded buffer whenever
- // IC * FH * FW is not a multiple of 4.
- bool ConvBiasForwardImpl::AlgoMatmul8x8x32::need_filter_reshape(
-         const SizeArgs& args) const {
-     auto&& fm = args.filter_meta;
-     // NHWC keeps channels on axis 3; NCHW4 stores channels / 4 on axis 1.
-     const size_t in_channels = fm.format == Param::Format::NHWC
-                                      ? args.src_layout->shape[3]
-                                      : args.src_layout->shape[1] * 4;
-     const size_t row_len = in_channels * fm.spatial[0] * fm.spatial[1];
-     return row_len % 4 != 0;
- }
-
- // Decide whether src must go through im2col before the GEMM.
- //
- // The src can be used directly only for a 1x1 filter with unit stride and zero
- // padding whose row stride is a multiple of 4 (cublasGemmEx requires the
- // stride of the unrolled src to be multiples of 4).
- bool ConvBiasForwardImpl::AlgoMatmul8x8x32::need_src_unroll(
-         const SizeArgs& args) const {
-     auto&& fm = args.filter_meta;
-
-     // NHWC: stride of axis 2 of src; NCHW4: the channel count, shape[1] * 4.
-     size_t row_stride;
-     if (fm.format == Param::Format::NHWC) {
-         row_stride = args.src_layout->stride[2];
-     } else {
-         row_stride = args.src_layout->shape[1] * 4;
-     }
-
-     const bool filter_is_1x1 = fm.spatial[0] == 1 && fm.spatial[1] == 1;
-     const bool unit_stride = fm.stride[0] == 1 && fm.stride[1] == 1;
-     const bool zero_padding = fm.padding[0] == 0 && fm.padding[1] == 0;
-     if (!filter_is_1x1 || !unit_stride || !zero_padding)
-         return true;
-
-     return row_stride % 4 != 0;
- }
- // vim: syntax=cpp.doxygen
|