diff --git a/dnn/scripts/opr_param_defs.py b/dnn/scripts/opr_param_defs.py
index d4ad3710..f0c5068a 100755
--- a/dnn/scripts/opr_param_defs.py
+++ b/dnn/scripts/opr_param_defs.py
@@ -36,7 +36,7 @@ pdef('Axis').add_fields('int32', 'axis', 0)
     add_enum(Doc('Format', 'convolution data/filter/output format; see '
                  ':class:`RelayoutFormat` for more details'),
              'NCHW', 'NHWC', 'NHWCD4', 'NCHW4', 'NCHW8', 'NCHW32', 'NCHW88',
-             'NCHW44','NCHW44_DOT',
+             'NCHW44', 'NCHW44_DOT',
              Doc('NCHW_WINOGRAD', 'NCHW layout with weights transformed by winograd'),
              Doc('NCHW88_WINOGRAD', 'NCHW88 layout with weights transformed by winograd'),
              Doc('NCHW44_WINOGRAD', 'NCHW44 layout with weights transformed by winograd'),
@@ -95,7 +95,7 @@ pdef('Axis').add_fields('int32', 'axis', 0)
     add_enum(Doc('Format', 'convolution data/filter/output format; see '
                  ':class:`RelayoutFormat` for more details'),
              'NCHW', 'NHWC', 'NHWCD4', 'NCHW4', 'NCHW8', 'NCHW32', 'NCHW88',
-             'NCHW44','NCHW44_DOT',
+             'NCHW44', 'NCHW44_DOT',
              Doc('NCHW4_NCHW32', 'NCHW4_NCHW32 means input tensors are nchw4 layout, output tensor is nchw32 layout'),
              Doc('NCHW32_NCHW4', 'NCHW32_NCHW4 means input tensors are nchw32 layout, output tensor is nchw4 layout'),
              Doc('NCHW4_NCHW', 'NCHW4_NCHW means input tensors are nchw4 layout, output tensor is nchw layout'),
@@ -106,7 +106,9 @@ pdef('Axis').add_fields('int32', 'axis', 0)
              Doc('NCHW_NCHW4_IC_SMALL', 'NCHW_NCHW4_IC_SMALL means input tensors are nchw(c < 4) layout, '
                  'output tensor is nchw4 layout, padding c=4'),
              Doc('CHWN4', 'CHWN4 is currently only used on Nvidia platform for fast implementation '
-                 'of convolution using CUDA/SASS. The channels are split into groups of 4 channels.')).
+                 'of convolution using CUDA/SASS. The channels are split into groups of 4 channels.'),
+             Doc('NCHW64', 'NCHW64 is designed for convolution implementations utilizing TensorCore '
+                 'instructions for 4-bit integers on Nvidia platforms')).
     add_enum_alias('ComputeMode', 'ConvolutionV1', name_field='compute_mode')
     )
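For readers unfamiliar with blocked layouts, here is a minimal sketch (not part of the patch; the helper name is illustrative) of how a logical NCHW coordinate maps into the NCHW64 layout added above:

```cpp
#include <cstddef>

// Logical (n, c, h, w) -> linear offset in an NCHW64 tensor of shape
// [N, C/64, H, W, 64]: the channel axis is split into an outer block
// index and an inner index over 64 interleaved channels.
// Assumes C % 64 == 0, as the layout requires.
size_t nchw64_offset(size_t n, size_t c, size_t h, size_t w,
                     size_t C, size_t H, size_t W) {
    size_t outer = c / 64, inner = c % 64;
    return (((n * (C / 64) + outer) * H + h) * W + w) * 64 + inner;
}
```

Keeping 64 int4 channels contiguous lets one 32-byte load fetch a whole channel block, which is what the TensorCore int4 path wants.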
diff --git a/dnn/src/common/conv_bias.cpp b/dnn/src/common/conv_bias.cpp
index 13c77fe7..276d0a60 100644
--- a/dnn/src/common/conv_bias.cpp
+++ b/dnn/src/common/conv_bias.cpp
@@ -36,28 +36,15 @@ ConvBiasForward::CanonizedFilterMeta ConvBiasForward::check_exec(
         const TensorLayout& dst, size_t workspace_in_bytes,
         const PreprocessedFilter* preprocessed_filter) {
     megdnn_assert(src.dtype.enumv() == filter.dtype.enumv());
-    if (src.dtype.enumv() == DTypeEnum::QuantizedS8) {
+    // check compatibility of bias's scale
+    if (src.dtype.category() == DTypeCategory::QUANTIZED) {
         if (bias.dtype.enumv() == DTypeEnum::QuantizedS32) {
-            float scale_src = src.dtype.param<dtype::QuantizedS8>().scale;
-            float scale_filter = filter.dtype.param<dtype::QuantizedS8>().scale;
+            float scale_expected = mul_scale(src.dtype, filter.dtype);
             float scale_bias = bias.dtype.param<dtype::QuantizedS32>().scale;
-            megdnn_assert(
-                    std::abs(scale_src * scale_filter - scale_bias) < 1e-6,
-                    "scale_src: %f scale_filter: %f scale_bias: %f", scale_src,
-                    scale_filter, scale_bias);
-        } else {
-            megdnn_assert(bias.dtype.enumv() == DTypeEnum::Float32);
-        }
-    } else if (src.dtype.enumv() == DTypeEnum::Quantized8Asymm) {
-        if (bias.dtype.enumv() == DTypeEnum::QuantizedS32) {
-            float scale_src = src.dtype.param<dtype::Quantized8Asymm>().scale;
-            float scale_filter =
-                    filter.dtype.param<dtype::Quantized8Asymm>().scale;
-            float scale_bias = bias.dtype.param<dtype::QuantizedS32>().scale;
-            megdnn_assert(
-                    std::abs(scale_src * scale_filter - scale_bias) < 1e-6,
-                    "scale_src: %f scale_filter: %f scale_bias: %f", scale_src,
-                    scale_filter, scale_bias);
+            megdnn_assert(std::abs(scale_expected - scale_bias) < 1e-6,
+                          "scale_src: %f scale_filter: %f scale_bias: %f",
+                          get_scale(src.dtype), get_scale(filter.dtype),
+                          scale_bias);
         } else {
             megdnn_assert(bias.dtype.enumv() == DTypeEnum::Float32);
         }
@@ -127,6 +114,13 @@ ConvBiasForward::CanonizedFilterMeta ConvBiasForward::check_exec(
             megdnn_assert(bias.shape[2] == 1);
             megdnn_assert(bias.shape[3] == 1);
             megdnn_assert(bias.shape[4] == 4);
+        } else if (param().format == param::ConvBias::Format::NCHW64) {
+            megdnn_assert(bias.shape[0] == 1);
+            megdnn_assert(bias.shape[1] == dst.shape[1], "bias:%s, dst:%s",
+                          bias.to_string().c_str(), dst.to_string().c_str());
+            megdnn_assert(bias.shape[2] == 1);
+            megdnn_assert(bias.shape[3] == 1);
+            megdnn_assert(bias.shape[4] == 64);
         } else {
             megdnn_assert(param().format == param::ConvBias::Format::NHWCD4);
             megdnn_assert(bias.shape[0] == 1);
diff --git a/dnn/src/common/convolution.cpp b/dnn/src/common/convolution.cpp
index ec7c70d6..b645cfbf 100644
--- a/dnn/src/common/convolution.cpp
+++ b/dnn/src/common/convolution.cpp
@@ -370,7 +370,8 @@ void make_canonized_filter_meta_nchwx(
                   param.format == Param::Format::NCHW32 ||
                   param.format == Param::Format::NCHW4_NCHW ||
                   param.format == Param::Format::NCHW4_NCHW32 ||
-                  param.format == Param::Format::NCHW32_NCHW4);
+                  param.format == Param::Format::NCHW32_NCHW4 ||
+                  param.format == Param::Format::NCHW64);
     auto img_ndim = src_ndim - 3;
     size_t flt_start = 0, flt_spatial_start = 2;
     if (param.sparse == Param::Sparse::DENSE) {
@@ -517,6 +518,9 @@ ConvolutionBase<Parameter>::make_canonized_filter_meta(
     } else if (param().format == Param::Format::CHWN4) {
         make_canonized_filter_meta_chwnx<4, Parameter>(src_ndim, filter,
                                                        param(), ret);
+    } else if (param().format == Param::Format::NCHW64) {
+        make_canonized_filter_meta_nchwx<64, Parameter>(src_ndim, filter,
+                                                        param(), ret);
     } else {
         megdnn_assert(param().format == Param::Format::NHWC ||
                       param().format == Param::Format::NCHW);
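The rewritten check in `check_exec` rests on a single invariant: the int32 bias is added to the raw int32 accumulator, so its scale must equal the product of the input scales. A standalone sketch of that invariant (illustrative names, not library code):

```cpp
#include <cassert>
#include <cmath>

// The accumulator dequantizes with scale_src * scale_filter, so an int32
// bias is only compatible when it was quantized with exactly that
// combined scale -- the condition mul_scale()/get_scale() enforce above.
void check_bias_scale(float scale_src, float scale_filter, float scale_bias) {
    assert(std::fabs(scale_src * scale_filter - scale_bias) < 1e-6f);
}
```

Collapsing the per-dtype branches into one `DTypeCategory::QUANTIZED` branch is what lets the new QuantizedS4 path reuse this check unchanged.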
@@ -539,6 +543,7 @@ void
 ConvolutionBase<Parameter>::check_or_deduce_dtype_fwd(DType src,
         supported_dst_dtype = {dtype::Int32(), dtype::Int16()};
     } else if (src.enumv() == DTypeEnum::QuantizedS8 ||
                src.enumv() == DTypeEnum::Quantized8Asymm ||
+               src.enumv() == DTypeEnum::QuantizedS4 ||
                src.enumv() == DTypeEnum::Quantized4Asymm) {
         supported_dst_dtype.push_back(
                 dtype::QuantizedS32(mul_scale(src, filter)));
@@ -614,7 +619,8 @@ ConvolutionBase<Parameter>::deduce_layout_fwd(const TensorLayout& src,
                       param().format == Param::Format::NCHW32 ||
                       param().format == Param::Format::NCHW32_NCHW4 ||
                       param().format == Param::Format::NCHW88 ||
-                      param().format == Param::Format::CHWN4);
+                      param().format == Param::Format::CHWN4 ||
+                      param().format == Param::Format::NCHW64);
         img_dim = src.ndim - 3;
         if ((param().format == Param::Format::NCHW88 ||
              param().format == Param::Format::NCHW44_DOT ||
@@ -712,6 +718,15 @@ ConvolutionBase<Parameter>::deduce_layout_fwd(const TensorLayout& src,
                           "but got src %s, filter %s", src.to_string().c_str(),
                           filter.to_string().c_str());
         }
+        if (param().format == Param::Format::NCHW64) {
+            megdnn_assert(src.ndim == 5 &&
+                                  (filter.ndim == 5 || filter.ndim == 6) &&
+                                  src[src.ndim - 1] == 64 &&
+                                  filter[filter.ndim - 1] == 64,
+                          "NCHW64 requires src ndim of 5 and filter ndim of 5 "
+                          "or 6, with the last dimension being 64, but got "
+                          "src %s, filter %s",
+                          src.to_string().c_str(), filter.to_string().c_str());
+        }
     }
     megdnn_assert(img_dim == 2,
                   "currently only convolution on 2D image is supported");
@@ -899,6 +914,23 @@ ConvolutionBase<Parameter>::deduce_layout_fwd(const TensorLayout& src,
         dst[3] = infer_conv_shape(src[3], cflt.dilated_spatial[1],
                                   cflt.stride[1], cflt.padding[1]);
         dst[4] = 4;
+    } else if (param().format == Param::Format::NCHW64) {
+        megdnn_assert(src.ndim == 5,
+                      "invalid src ndim for NCHW64, expected=5, got=%zu",
+                      src.ndim);
+        megdnn_assert(cflt.icpg * cflt.group == src[1] * 64,
+                      "%s icpg=%u group=%u", errmsg().c_str(), cflt.icpg,
+                      cflt.group);
+        dst.ndim = src.ndim;
+        dst[0] = src[0];
+        auto oc = cflt.ocpg * cflt.group;
+        megdnn_assert(oc % 64 == 0);
+        dst[1] = oc / 64;
+        dst[2] = infer_conv_shape(src[2], cflt.dilated_spatial[0],
+                                  cflt.stride[0], cflt.padding[0]);
+        dst[3] = infer_conv_shape(src[3], cflt.dilated_spatial[1],
+                                  cflt.stride[1], cflt.padding[1]);
+        dst[4] = 64;
     } else {
         megdnn_assert(param().format == Param::Format::NHWCD4);
         megdnn_assert(src.ndim == 5,
diff --git a/dnn/src/common/utils.cpp b/dnn/src/common/utils.cpp
index 7362550f..5609d280 100644
--- a/dnn/src/common/utils.cpp
+++ b/dnn/src/common/utils.cpp
@@ -245,6 +245,17 @@ float megdnn::mul_scale(DType lhs, DType rhs) {
 }
 // clang-format on
 
+float megdnn::get_scale(DType dt) {
+    megdnn_assert(dt.category() == DTypeCategory::QUANTIZED);
+#define cb(_dt)                               \
+    if (dt.enumv() == DTypeTrait<_dt>::enumv) \
+        return dt.param<_dt>().scale;
+    MEGDNN_FOREACH_QUANTIZED_DTYPE(cb)
+    MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb)
+#undef cb
+    megdnn_assert_internal(0);
+}
+
 bool megdnn::dtype_almost_equal(DType lhs, DType rhs) {
     if (lhs.enumv() != rhs.enumv())
         return false;
diff --git a/dnn/src/common/utils.h b/dnn/src/common/utils.h
index 1f306398..11077216 100644
--- a/dnn/src/common/utils.h
+++ b/dnn/src/common/utils.h
@@ -504,6 +504,8 @@ bool vec_contains(const SmallVector<T>& vec, const T& elem) {
 
 float mul_scale(DType lhs, DType rhs);
 
+float get_scale(DType dt);
+
 template <typename stype, typename dtype>
 dtype convert(stype src, dtype dst, size_t offset);
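The NCHW64 branch of `deduce_layout_fwd` uses the usual convolution shape arithmetic; here is a self-contained sketch of that computation (mirroring `infer_conv_shape`, which is not shown in this patch):

```cpp
#include <cstddef>

// floor((in + 2*pad - dilated_filter) / stride) + 1, the standard output
// size. For NCHW64, dst[2]/dst[3] come from this, dst[1] = oc / 64, and
// dst[4] stays 64.
size_t out_spatial(size_t in, size_t dilated_filter, size_t stride,
                   size_t pad) {
    return (in + 2 * pad - dilated_filter) / stride + 1;
}
// e.g. in=56, filter=3, stride=1, pad=1 -> out_spatial(...) == 56
```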
diff --git a/dnn/src/cuda/conv_bias/algo.h b/dnn/src/cuda/conv_bias/algo.h
index 526e2ab4..e2597350 100644
--- a/dnn/src/cuda/conv_bias/algo.h
+++ b/dnn/src/cuda/conv_bias/algo.h
@@ -807,7 +807,6 @@ public:
     AlgoBatchedMatmul batched_matmul;
     std::vector<AlgoInt8NCHW4DotProdImplicitGemm> int8_nchw4_dotprod;
     AlgoInt8CHWN4DotProdImplicitGemm int8_chwn4_dotprod;
-<<<<<<< HEAD
 #if CUDA_VERSION >= 10000
     AlgoQUInt4x4x32WMMA wmma_quint4x4x32;
     std::vector<AlgoInt8CHWN4IMMAImplicitGemm> int8_chwn4_imma;
diff --git a/dnn/src/cuda/conv_bias/conv_bias_int8.cuh b/dnn/src/cuda/conv_bias/conv_bias_int8.cuh
index 62968ea4..a5f4531f 100644
--- a/dnn/src/cuda/conv_bias/conv_bias_int8.cuh
+++ b/dnn/src/cuda/conv_bias/conv_bias_int8.cuh
@@ -150,4 +150,12 @@ void do_conv_bias_int8_implicit_gemm_imma8x32x16_cdiv4hwn4_unroll_width(
         UNPACK_CONV_PARAMETER(_filter_meta, _param); \
         MARK_USED_VAR
 
+#define UNPACK_CONV_BIAS_NCHW64_PARAM(_src, _filter_meta, _dst, _param)        \
+    using Format = param::ConvBias::Format;                                    \
+    megdnn_assert(_param.format == Format::NCHW64);                            \
+    size_t n = (_src)[0], ci = (_src)[1] * 64, hi = (_src)[2], wi = (_src)[3]; \
+    size_t co = (_dst)[1] * 64, ho = (_dst)[2], wo = (_dst)[3];                \
+    UNPACK_CONV_PARAMETER(_filter_meta, _param);                               \
+    MARK_USED_VAR
+
 // vim: syntax=cuda.doxygen
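What `UNPACK_CONV_BIAS_NCHW64_PARAM` computes, spelled out for a concrete case with plain variables (a sketch only; the real macro also unpacks filter-meta fields such as fh/fw/sh/sw via `UNPACK_CONV_PARAMETER`):

```cpp
#include <cstddef>
#include <cstdio>

int main() {
    size_t src[5] = {16, 4, 56, 56, 64};  // N, C/64, H, W, 64
    size_t dst[5] = {16, 2, 56, 56, 64};  // N, OC/64, Ho, Wo, 64
    // logical channel counts are recovered by multiplying the block
    // dimension back in:
    size_t n = src[0], ci = src[1] * 64, hi = src[2], wi = src[3];
    size_t co = dst[1] * 64, ho = dst[2], wo = dst[3];
    std::printf("n=%zu ci=%zu hi=%zu wi=%zu co=%zu ho=%zu wo=%zu\n",
                n, ci, hi, wi, co, ho, wo);
    // prints: n=16 ci=256 hi=56 wi=56 co=128 ho=56 wo=56
    return 0;
}
```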
diff --git a/dnn/src/cuda/conv_bias/sass_implicit_gemm_int4_nchw64_imma.cpp b/dnn/src/cuda/conv_bias/sass_implicit_gemm_int4_nchw64_imma.cpp
new file mode 100644
index 00000000..fbc851c5
--- /dev/null
+++ b/dnn/src/cuda/conv_bias/sass_implicit_gemm_int4_nchw64_imma.cpp
@@ -0,0 +1,302 @@
+/**
+ * \file dnn/src/cuda/conv_bias/sass_implicit_gemm_int4_nchw64_imma.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+
+#include "./algo.h"
+#include "src/cuda/conv_bias/sass_helper.cuh"
+#include "src/cuda/sass_loader.h"
+#include "src/cuda/utils.h"
+#include "src/common/conv_bias.h"
+
+using namespace megdnn;
+using namespace cuda;
+using namespace sass;
+
+namespace {
+#if !MEGDNN_TEGRA_X1
+// all strides are in bytes
+void compute_conv2d_offset(size_t fh, size_t fw, size_t ics, size_t ihs,
+                           Conv2dConstantOffset& constant_offset) {
+    constexpr int interleaved = 64;
+    constexpr int size_bits = 4;
+    constexpr int threadblock_k = 128;
+    constexpr int inc_step = threadblock_k / interleaved;
+    size_t i = 0;
+    int* s32 = reinterpret_cast<int*>(&(constant_offset.c_offset[0]));
+    for (; i < inc_step; i++) {
+        int c = i / (fh * fw);
+        int khkw = i % (fh * fw);
+        int kh = khkw / fw;
+        int kw = khkw % fw;
+        s32[2 * i] = c * ics + kh * ihs + kw * interleaved * size_bits / 8;
+        int8_t* s8 = reinterpret_cast<int8_t*>(&(s32[2 * i + 1]));
+        s8[0] = kh;
+        s8[1] = kw;
+        s8[2] = -kh;
+        s8[3] = -kw;
+    }
+    for (; i < (inc_step + fh * fw * inc_step); i++) {
+        int c = i / (fh * fw);
+        int khkw = i % (fh * fw);
+        int kh = khkw / fw;
+        int kw = khkw % fw;
+        s32[2 * i] = c * ics + kh * ihs + kw * interleaved * size_bits / 8;
+        int8_t* s8 = reinterpret_cast<int8_t*>(&(s32[2 * i + 1]));
+        s8[0] = kh;
+        s8[1] = kw;
+        s8[2] = -kh;
+        s8[3] = -kw;
+        int i_ = i - inc_step;
+        c = i_ / (fh * fw);
+        khkw = i_ % (fh * fw);
+        kh = khkw / fw;
+        kw = khkw % fw;
+        s32[2 * i] -= c * ics + kh * ihs + kw * interleaved * size_bits / 8;
+    }
+}
+#endif
+};  // namespace
+
+std::string ConvBiasForwardImpl::AlgoSASSInt4NCHW64IMMAImplicitGemm::kernel_key(
+        const SizeArgs& args) const {
+    std::string kernel_key;
+    using NonlineMode = Param::NonlineMode;
+    auto&& param = args.opr->param();
+    if (args.z_layout->ndim > 0) {
+        kernel_key =
+                ssprintf("%s_conv_bias_int4_fuse_z_imma_ldg16_%ux%u",
+                         current_device_arch_name(), m_tile_nhw, m_tile_oc);
+    } else {
+        kernel_key =
+                ssprintf("%s_conv_bias_int4_imma_ldg16_%ux%u",
+                         current_device_arch_name(), m_tile_nhw, m_tile_oc);
+    }
+    if (param.nonlineMode == NonlineMode::H_SWISH) {
+        kernel_key += "_hswish";
+    } else {
+        megdnn_assert(param.nonlineMode == NonlineMode::RELU ||
+                      param.nonlineMode == NonlineMode::IDENTITY);
+        kernel_key += "_relu";
+    }
+    return kernel_key;
+}
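To see how `compute_conv2d_offset` walks the filter window, here is the bare index decomposition it uses, extracted into a runnable sketch (the byte-offset math and the `Conv2dConstantOffset` packing are omitted; `inc_step = threadblock_k / interleaved` as in the file above):

```cpp
#include <cstdio>

int main() {
    const int fh = 3, fw = 3, inc_step = 2;  // 128 / 64
    for (int i = 0; i < inc_step * fh * fw; ++i) {
        int c = i / (fh * fw);     // channel-block step
        int khkw = i % (fh * fw);  // flat position inside the filter window
        int kh = khkw / fw, kw = khkw % fw;
        std::printf("i=%2d -> c=%d kh=%d kw=%d\n", i, c, kh, kw);
    }
    return 0;
}
```

Each table entry additionally stores (kh, kw, -kh, -kw) as bytes so the kernel can test padding bounds without recomputing the decomposition.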
+bool ConvBiasForwardImpl::AlgoSASSInt4NCHW64IMMAImplicitGemm::is_available(
+        const SizeArgs& args) const {
+    if (args.bias_layout->ndim <= 0)
+        return false;
+    using Param = param::ConvBias;
+    using Format = Param::Format;
+    using Sparse = Param::Sparse;
+    using Mode = Param::Mode;
+    bool available = true;
+    auto&& param = args.opr->param();
+    auto&& fm = args.filter_meta;
+    if (!check_bias_share_in_channel(*(args.bias_layout), param.format))
+        return false;
+    if (param.format != Format::NCHW64)
+        return false;
+    UNPACK_CONV_BIAS_NCHW64_PARAM(*(args.src_layout), fm, *(args.dst_layout),
+                                  param);
+    // TODO: support group conv
+    available &= param.sparse == Sparse::DENSE;
+    // mode must be cross correlation
+    available &= param.mode == Mode::CROSS_CORRELATION;
+    // check data type
+    auto src_dtype = args.src_layout->dtype,
+         filter_dtype = args.filter_layout->dtype,
+         bias_dtype = args.bias_layout->dtype,
+         dst_dtype = args.dst_layout->dtype;
+    available &= (src_dtype.enumv() == DTypeEnum::QuantizedS4 &&
+                  filter_dtype.enumv() == DTypeEnum::QuantizedS4 &&
+                  bias_dtype.enumv() == DTypeEnum::QuantizedS32 &&
+                  dst_dtype.enumv() == DTypeEnum::QuantizedS4);
+    // TODO: support dilation
+    available &= dh == 1 && dw == 1;
+    // ensure precomputed offsets are positive integers
+    available &= hi >= fh && wi >= fw;
+    // only support sm_75 or later; the platform must provide Tensor Core
+    // int4 support
+    available &= is_compute_capability_required(7, 5);
+    // the param buffer is 4KB and 3KB of it stores precomputed offsets,
+    // so fh * fw <= (3 * 1024 / 4 / 2 / 2) - 1 = 191
+    available &= fh * fw <= 191;
+    return available;
+}
+
+size_t
+ConvBiasForwardImpl::AlgoSASSInt4NCHW64IMMAImplicitGemm::get_workspace_in_bytes(
+        const SizeArgs& args) const {
+    if (args.preprocessed_filter == nullptr) {
+        return args.filter_layout->span().dist_byte() +
+               args.bias_layout->span().dist_byte();
+    }
+    return 0_z;
+}
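A compact restatement of the workspace rule above (sketch only, illustrative name): scratch space is needed just for the reordered filter/bias copy, and only when no preprocessed filter is supplied.

```cpp
#include <cstddef>

// exec() reorders the filter and bias into workspace on the fly when they
// were not prepared ahead of time by exec_preprocess(); with preprocessing,
// no extra workspace is required.
size_t workspace_bytes(bool has_preprocessed_filter, size_t filter_bytes,
                       size_t bias_bytes) {
    return has_preprocessed_filter ? 0 : filter_bytes + bias_bytes;
}
```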
+void ConvBiasForwardImpl::AlgoSASSInt4NCHW64IMMAImplicitGemm::exec(
+        const ExecArgs& args) const {
+#if MEGDNN_TEGRA_X1
+    megdnn_throw("sass kernel is disabled at compile time for TX1");
+#else
+    using Format = Param::Format;
+    auto&& param = args.opr->param();
+    auto&& fm = args.filter_meta;
+    UNPACK_CONV_BIAS_NCHW64_PARAM(*(args.src_layout), fm, *(args.dst_layout),
+                                  param);
+    auto&& stream = cuda_stream(args.opr->handle());
+    constexpr int interleaved = 64;
+
+    void* bias_ptr = nullptr;
+    void* filter_ptr = nullptr;
+    if (args.preprocessed_filter) {
+        megdnn_assert(args.preprocessed_filter->tensors.size() == 2);
+        filter_ptr = args.preprocessed_filter->tensors[0].raw_ptr;
+        bias_ptr = args.preprocessed_filter->tensors[1].raw_ptr;
+    } else {
+        // reorder filter and bias
+        filter_ptr = reinterpret_cast<void*>(args.workspace.raw_ptr);
+        bias_ptr = reinterpret_cast<void*>(
+                args.workspace.raw_ptr + args.filter_layout->span().dist_byte());
+        reorder_imma_filter_bias<4, 64>(
+                reinterpret_cast<int8_t*>(filter_ptr),
+                reinterpret_cast<int32_t*>(bias_ptr),
+                args.filter_tensor->compatible_ptr<int8_t>(),
+                args.bias_tensor->compatible_ptr<int32_t>(), co, ci, fh, fw,
+                stream);
+    }
+
+    uint32_t u32_n = n, u32_ci = ci, u32_hi = hi, u32_wi = wi, u32_fh = fh,
+             u32_fw = fw, u32_sh = sh, u32_sw = sw, u32_ph = ph, u32_pw = pw,
+             u32_co = co, u32_ho = ho, u32_wo = wo;
+    Conv2dInt4Param kern_param(u32_n, u32_ci, u32_hi, u32_wi, u32_fh, u32_fw,
+                               u32_sh, u32_sw, u32_ph, u32_pw, u32_co, u32_ho,
+                               u32_wo, interleaved);
+
+    Conv2dConstantOffset kern_coffset;
+    compute_conv2d_offset(fh, fw, kern_param.ics, kern_param.ihs, kern_coffset);
+    // the starting address of the Turing param buffer is c[0x0][0x160]
+    kern_coffset.c_offset_param.begin = param_buffer_start_address();
+    kern_coffset.c_offset_param.size = 16 * (1 + fh * fw);
+    kern_coffset.c_offset_param.max = 16 * fh * fw;
+    kern_coffset.c_offset_param.rewind = 16 * (1 - fh * fw);
+
+    auto kern_key = kernel_key(args);
+    float src_scale = args.src_layout->dtype.param<dtype::QuantizedS4>().scale,
+          filter_scale =
+                  args.filter_layout->dtype.param<dtype::QuantizedS4>().scale,
+          bias_scale =
+                  args.bias_layout->dtype.param<dtype::QuantizedS32>().scale,
+          dst_scale = args.dst_layout->dtype.param<dtype::QuantizedS4>().scale;
+    float alpha = src_scale * filter_scale / dst_scale,
+          beta = bias_scale / dst_scale;
+    float inv_dst_scale = 1.f / dst_scale;
+
+    unsigned int tx = m_threads, ty = 1;
+    unsigned int gridx = div_ceil(
+            static_cast<uint32_t>(n * ho * wo), static_cast<uint32_t>(m_tile_nhw));
+    unsigned int gridy =
+            div_ceil(static_cast<uint32_t>(co), static_cast<uint32_t>(m_tile_oc));
+    void* src_ptr = const_cast<void*>(args.src_tensor->raw_ptr);
+    void* dst_ptr = const_cast<void*>(args.dst_tensor->raw_ptr);
+
+    using NonlineMode = Param::NonlineMode;
+    auto&& kernel = SASSKernelLoader::instance().get_kernel(kern_key, kern_key);
+    if (args.z_layout->ndim > 0) {
+        void* z_ptr = const_cast<void*>(args.z_tensor->raw_ptr);
+        float z_scale = args.z_layout->dtype.param<dtype::QuantizedS4>().scale;
+        float gamma = z_scale / dst_scale;
+        std::vector<void*> params = {&src_ptr, &filter_ptr, &bias_ptr, &z_ptr,
+                                     &dst_ptr, &alpha,      &beta,     &gamma};
+        kern_coffset.c_offset_param.begin +=
+                sizeof(src_ptr) + sizeof(filter_ptr) + sizeof(bias_ptr) +
+                sizeof(z_ptr) + sizeof(dst_ptr) + sizeof(alpha) + sizeof(beta) +
+                sizeof(gamma);
+
+        uint32_t relu = param.nonlineMode == NonlineMode::RELU ? 1 : 0;
+        if (param.nonlineMode == NonlineMode::H_SWISH) {
+            params.push_back(&dst_scale);
+            params.push_back(&inv_dst_scale);
+            kern_coffset.c_offset_param.begin +=
+                    sizeof(dst_scale) + sizeof(inv_dst_scale);
+        } else {
+            params.push_back(&relu);
+            kern_coffset.c_offset_param.begin += sizeof(relu);
+        }
+        params.push_back(&kern_param);
+        kern_coffset.c_offset_param.begin += sizeof(kern_param);
+        kern_coffset.c_offset_param.begin +=
+                sizeof(kern_coffset.c_offset_param);
+        kern_coffset.c_offset_param.max += kern_coffset.c_offset_param.begin;
+        params.push_back(&kern_coffset);
+        cucheck(cuLaunchKernel(kernel, gridx, gridy, 1, tx, ty, 1, 0, stream,
+                               params.data(), 0));
+    } else {
+        std::vector<void*> params = {&src_ptr, &filter_ptr, &bias_ptr,
+                                     &dst_ptr, &alpha,      &beta};
+
+        kern_coffset.c_offset_param.begin +=
+                sizeof(src_ptr) + sizeof(filter_ptr) + sizeof(bias_ptr) +
+                sizeof(dst_ptr) + sizeof(alpha) + sizeof(beta);
+
+        uint32_t relu = param.nonlineMode == NonlineMode::RELU ? 1 : 0;
+        if (param.nonlineMode == NonlineMode::H_SWISH) {
+            params.push_back(&dst_scale);
+            params.push_back(&inv_dst_scale);
+            kern_coffset.c_offset_param.begin +=
+                    sizeof(dst_scale) + sizeof(inv_dst_scale);
+        } else {
+            params.push_back(&relu);
+            kern_coffset.c_offset_param.begin += sizeof(relu);
+        }
+        params.push_back(&kern_param);
+        kern_coffset.c_offset_param.begin += sizeof(kern_param);
+        kern_coffset.c_offset_param.begin +=
+                sizeof(kern_coffset.c_offset_param);
+        kern_coffset.c_offset_param.max += kern_coffset.c_offset_param.begin;
+        params.push_back(&kern_coffset);
+        cucheck(cuLaunchKernel(kernel, gridx, gridy, 1, tx, ty, 1, 0, stream,
+                               params.data(), 0));
+    }
+    after_kernel_launch();
+#endif
+}
+
+size_t ConvBiasForwardImpl::AlgoSASSInt4NCHW64IMMAImplicitGemm::
+        get_preprocess_workspace_in_bytes(const SizeArgs& args) const {
+    return 0_z;
+}
+
+SmallVector<TensorLayout> ConvBiasForwardImpl::
+        AlgoSASSInt4NCHW64IMMAImplicitGemm::deduce_preprocessed_filter_layout(
+                const SizeArgs& args) const {
+    return {args.filter_layout->collapse_contiguous(),
+            args.bias_layout->collapse_contiguous()};
+}
+
+void ConvBiasForwardImpl::AlgoSASSInt4NCHW64IMMAImplicitGemm::exec_preprocess(
+        const ExecArgs& args) const {
+    using Format = Param::Format;
+    auto&& param = args.opr->param();
+    auto&& fm = args.filter_meta;
+    UNPACK_CONV_BIAS_NCHW64_PARAM(*(args.src_layout), fm, *(args.dst_layout),
+                                  param);
+    auto&& stream = cuda_stream(args.opr->handle());
+    reorder_imma_filter_bias<4, 64>(
+            args.preprocessed_filter->tensors[0].compatible_ptr<int8_t>(),
+            args.preprocessed_filter->tensors[1].compatible_ptr<int32_t>(),
+            args.filter_tensor->compatible_ptr<int8_t>(),
+            args.bias_tensor->compatible_ptr<int32_t>(), co, ci, fh, fw,
+            stream);
+}
+
+// vim: syntax=cpp.doxygen
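Finally, the scale algebra used by `exec()` when folding the int32 accumulator into the int4 output, as a hedged sketch (matching alpha = s_src * s_filter / s_dst, beta = s_bias / s_dst, and gamma = s_z / s_dst from the code above; not the SASS kernel's actual epilogue):

```cpp
// Value of one output element in the dst quantized domain, before the
// nonlinearity and before rounding/clamping to the int4 range.
float epilogue(int acc, int bias, float alpha, float beta,
               int z = 0, float gamma = 0.f) {
    return alpha * static_cast<float>(acc) + beta * static_cast<float>(bias) +
           gamma * static_cast<float>(z);
}
```

Expressing everything relative to the destination scale is what lets the kernel emit int4 results directly, with H_SWISH additionally receiving dst_scale and its inverse as extra launch parameters.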