
feat(dnn/cuda): add conv bias impl for int4 data type using sass language

GitOrigin-RevId: ae3d3e1c98
release-1.5
Megvii Engine Team, 4 years ago
commit ed92207585
8 changed files with 376 additions and 26 deletions
  1. dnn/scripts/opr_param_defs.py  (+5, -3)
  2. dnn/src/common/conv_bias.cpp  (+14, -20)
  3. dnn/src/common/convolution.cpp  (+34, -2)
  4. dnn/src/common/utils.cpp  (+11, -0)
  5. dnn/src/common/utils.h  (+2, -0)
  6. dnn/src/cuda/conv_bias/algo.h  (+0, -1)
  7. dnn/src/cuda/conv_bias/conv_bias_int8.cuh  (+8, -0)
  8. dnn/src/cuda/conv_bias/sass_implicit_gemm_int4_nchw64_imma.cpp  (+302, -0)

+5 -3  dnn/scripts/opr_param_defs.py

@@ -36,7 +36,7 @@ pdef('Axis').add_fields('int32', 'axis', 0)
add_enum(Doc('Format', 'convolution data/filter/output format; see '
':class:`RelayoutFormat` for more details'),
'NCHW', 'NHWC', 'NHWCD4', 'NCHW4', 'NCHW8', 'NCHW32', 'NCHW88',
'NCHW44','NCHW44_DOT',
Doc('NCHW_WINOGRAD', 'NCHW layout with weights transformed by winograd'),
Doc('NCHW88_WINOGRAD', 'NCHW88 layout with weights transformed by winograd'),
Doc('NCHW44_WINOGRAD', 'NCHW44 layout with weights transformed by winograd'),
@@ -95,7 +95,7 @@ pdef('Axis').add_fields('int32', 'axis', 0)
add_enum(Doc('Format', 'convolution data/filter/output format; see '
':class:`RelayoutFormat` for more details'),
'NCHW', 'NHWC', 'NHWCD4', 'NCHW4', 'NCHW8', 'NCHW32', 'NCHW88',
'NCHW44','NCHW44_DOT',
Doc('NCHW4_NCHW32', 'NCHW4_NCHW32 means input tensors are nchw4 layout, output tensor is nchw32 layout'),
Doc('NCHW32_NCHW4', 'NCHW32_NCHW4 means input tensors are nchw32 layout, output tensor is nchw4 layout'),
Doc('NCHW4_NCHW', 'NCHW4_NCHW means input tensors are nchw4 layout, output tensor is nchw layout'),
@@ -106,7 +106,9 @@ pdef('Axis').add_fields('int32', 'axis', 0)
Doc('NCHW_NCHW4_IC_SMALL', 'NCHW_NCHW4_IC_SMALL means input tensors are nchw(c < 4) layout, '
'output tensor is nchw4 layout, padding c=4'),
Doc('CHWN4', 'CHWN4 is currently only used on Nvidia platform for fast implementation '
'of convolution using CUDA/SASS. The channels are split into groups of 4 channels.')).
'of convolution using CUDA/SASS. The channels are split into groups of 4 channels.'),
Doc('NCHW64', 'NCHW64 is designed for convolution implementations utilizing TensorCore '
'instructions for 4-bit integers on Nvidia platforms')).
add_enum_alias('ComputeMode', 'ConvolutionV1',name_field='compute_mode')
)
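
Note: the new NCHW64 entry above describes a 5-D layout (N, C/64, H, W, 64) in which channels are packed into blocks of 64, matching the tensor core int4 path this commit targets; with 4-bit data, two elements share one byte. The following standalone sketch (hypothetical helper names, not part of this patch) illustrates the shape mapping and byte size:

#include <cassert>
#include <cstddef>
#include <cstdio>

// Hypothetical helper: map a dense NCHW shape to the 5-D NCHW64 shape
// (N, C/64, H, W, 64) and compute its byte size for int4 data.
struct Nchw64Shape {
    size_t n, c_blk, h, w, inner;  // inner is always 64
};

Nchw64Shape to_nchw64(size_t n, size_t c, size_t h, size_t w) {
    assert(c % 64 == 0);  // channels must be padded to a multiple of 64
    return {n, c / 64, h, w, 64};
}

size_t int4_bytes(const Nchw64Shape& s) {
    size_t elems = s.n * s.c_blk * s.h * s.w * s.inner;
    return elems / 2;  // size_bits = 4, so two int4 values per byte
}

int main() {
    Nchw64Shape s = to_nchw64(1, 128, 56, 56);  // e.g. a (1, 128, 56, 56) activation
    std::printf("NCHW64 shape: (%zu, %zu, %zu, %zu, %zu), %zu bytes\n",
                s.n, s.c_blk, s.h, s.w, s.inner, int4_bytes(s));
    return 0;
}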



+14 -20  dnn/src/common/conv_bias.cpp

@@ -36,28 +36,15 @@ ConvBiasForward::CanonizedFilterMeta ConvBiasForward::check_exec(
const TensorLayout& dst, size_t workspace_in_bytes,
const PreprocessedFilter* preprocessed_filter) {
megdnn_assert(src.dtype.enumv() == filter.dtype.enumv());
- if (src.dtype.enumv() == DTypeEnum::QuantizedS8) {
+ // check compatibility of bias's scale
+ if (src.dtype.category() == DTypeCategory::QUANTIZED) {
if (bias.dtype.enumv() == DTypeEnum::QuantizedS32) {
- float scale_src = src.dtype.param<dtype::QuantizedS8>().scale;
- float scale_filter = filter.dtype.param<dtype::QuantizedS8>().scale;
+ float scale_expected = mul_scale(src.dtype, filter.dtype);
float scale_bias = bias.dtype.param<dtype::QuantizedS32>().scale;
- megdnn_assert(
- std::abs(scale_src * scale_filter - scale_bias) < 1e-6,
- "scale_src: %f scale_filter: %f scale_bias: %f", scale_src,
- scale_filter, scale_bias);
- } else {
- megdnn_assert(bias.dtype.enumv() == DTypeEnum::Float32);
- }
- } else if (src.dtype.enumv() == DTypeEnum::Quantized8Asymm) {
- if (bias.dtype.enumv() == DTypeEnum::QuantizedS32) {
- float scale_src = src.dtype.param<dtype::Quantized8Asymm>().scale;
- float scale_filter =
- filter.dtype.param<dtype::Quantized8Asymm>().scale;
- float scale_bias = bias.dtype.param<dtype::QuantizedS32>().scale;
- megdnn_assert(
- std::abs(scale_src * scale_filter - scale_bias) < 1e-6,
- "scale_src: %f scale_filter: %f scale_bias: %f", scale_src,
- scale_filter, scale_bias);
+ megdnn_assert(std::abs(scale_expected - scale_bias) < 1e-6,
+ "scale_src: %f scale_filter: %f scale_bias: %f",
+ get_scale(src.dtype), get_scale(filter.dtype),
+ scale_bias);
} else {
megdnn_assert(bias.dtype.enumv() == DTypeEnum::Float32);
}
@@ -127,6 +114,13 @@ ConvBiasForward::CanonizedFilterMeta ConvBiasForward::check_exec(
megdnn_assert(bias.shape[2] == 1);
megdnn_assert(bias.shape[3] == 1);
megdnn_assert(bias.shape[4] == 4);
} else if (param().format == param::ConvBias::Format::NCHW64) {
megdnn_assert(bias.shape[0] == 1);
megdnn_assert(bias.shape[1] == dst.shape[1], "bias:%s, dst:%s",
bias.to_string().c_str(), dst.to_string().c_str());
megdnn_assert(bias.shape[2] == 1);
megdnn_assert(bias.shape[3] == 1);
megdnn_assert(bias.shape[4] == 64);
} else {
megdnn_assert(param().format == param::ConvBias::Format::NHWCD4);
megdnn_assert(bias.shape[0] == 1);
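
Note: the reworked check above collapses the per-dtype branches into one rule for all quantized types: a QuantizedS32 bias must carry a scale equal to scale(src) * scale(filter), which mul_scale computes and the new get_scale helper (added in dnn/src/common/utils.cpp below) reports in the error message. A standalone sketch of that invariant with illustrative numbers, not megdnn types:

#include <cassert>
#include <cmath>
#include <cstdio>

// Stand-in values for the quantized dtype scales involved in the check.
// In megdnn, get_scale(dt) returns the scale of a quantized dtype and
// mul_scale(a, b) returns get_scale(a) * get_scale(b).
int main() {
    float scale_src = 0.02f;      // e.g. QuantizedS4 activation scale
    float scale_filter = 0.01f;   // e.g. QuantizedS4 weight scale
    float scale_expected = scale_src * scale_filter;  // mul_scale(src, filter)
    float scale_bias = 0.0002f;   // QuantizedS32 bias scale

    // same tolerance as check_exec: the bias scale must match the product
    assert(std::fabs(scale_expected - scale_bias) < 1e-6f);
    std::printf("bias scale %.6f matches expected %.6f\n", scale_bias, scale_expected);
    return 0;
}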


+34 -2  dnn/src/common/convolution.cpp

@@ -370,7 +370,8 @@ void make_canonized_filter_meta_nchwx(
param.format == Param::Format::NCHW32 ||
param.format == Param::Format::NCHW4_NCHW ||
param.format == Param::Format::NCHW4_NCHW32 ||
param.format == Param::Format::NCHW32_NCHW4);
param.format == Param::Format::NCHW32_NCHW4 ||
param.format == Param::Format::NCHW64);
auto img_ndim = src_ndim - 3;
size_t flt_start = 0, flt_spatial_start = 2;
if (param.sparse == Param::Sparse::DENSE) {
@@ -517,6 +518,9 @@ ConvolutionBase<Parameter>::make_canonized_filter_meta(
} else if (param().format == Param::Format::CHWN4) {
make_canonized_filter_meta_chwnx<4, Parameter>(src_ndim, filter,
param(), ret);
} else if (param().format == Param::Format::NCHW64) {
make_canonized_filter_meta_nchwx<64, Parameter>(src_ndim, filter,
param(), ret);
} else {
megdnn_assert(param().format == Param::Format::NHWC ||
param().format == Param::Format::NCHW);
@@ -539,6 +543,7 @@ void ConvolutionBase<Parameter>::check_or_deduce_dtype_fwd(DType src,
supported_dst_dtype = {dtype::Int32(), dtype::Int16()};
} else if (src.enumv() == DTypeEnum::QuantizedS8 ||
src.enumv() == DTypeEnum::Quantized8Asymm ||
src.enumv() == DTypeEnum::QuantizedS4 ||
src.enumv() == DTypeEnum::Quantized4Asymm) {
supported_dst_dtype.push_back(
dtype::QuantizedS32(mul_scale(src, filter)));
@@ -614,7 +619,8 @@ ConvolutionBase<Parameter>::deduce_layout_fwd(const TensorLayout& src,
param().format == Param::Format::NCHW32 ||
param().format == Param::Format::NCHW32_NCHW4 ||
param().format == Param::Format::NCHW88 ||
param().format == Param::Format::CHWN4);
param().format == Param::Format::CHWN4 ||
param().format == Param::Format::NCHW64);
img_dim = src.ndim - 3;
if ((param().format == Param::Format::NCHW88 ||
param().format == Param::Format::NCHW44_DOT ||
@@ -712,6 +718,15 @@ ConvolutionBase<Parameter>::deduce_layout_fwd(const TensorLayout& src,
"but got src %s, filter %s",
src.to_string().c_str(), filter.to_string().c_str());
}
if (param().format == Param::Format::NCHW64) {
megdnn_assert(src.ndim == 5 &&
(filter.ndim == 5 || filter.ndim == 6) &&
src[src.ndim - 1] == 64 &&
filter[filter.ndim - 1] == 4,
"NCHW64 require src and filter's ndim is 5 or 6, and "
"last shape is 64 but got src %s, filter %s",
src.to_string().c_str(), filter.to_string().c_str());
}
}
megdnn_assert(img_dim == 2,
"currently only convolution on 2D image is supported");
@@ -899,6 +914,23 @@ ConvolutionBase<Parameter>::deduce_layout_fwd(const TensorLayout& src,
dst[3] = infer_conv_shape(src[3], cflt.dilated_spatial[1],
cflt.stride[1], cflt.padding[1]);
dst[4] = 4;
} else if (param().format == Param::Format::NCHW64) {
megdnn_assert(src.ndim == 5,
"invalid src ndim for NCHW64, expected=5, got=%zu",
src.ndim);
megdnn_assert(cflt.icpg * cflt.group == src[1] * 64,
"%s icpg=%u group=%u", errmsg().c_str(), cflt.icpg,
cflt.group);
dst.ndim = src.ndim;
dst[0] = src[0];
auto oc = cflt.ocpg * cflt.group;
megdnn_assert(oc % 64 == 0);
dst[1] = oc / 64;
dst[2] = infer_conv_shape(src[2], cflt.dilated_spatial[0],
cflt.stride[0], cflt.padding[0]);
dst[3] = infer_conv_shape(src[3], cflt.dilated_spatial[1],
cflt.stride[1], cflt.padding[1]);
dst[4] = 64;
} else {
megdnn_assert(param().format == Param::Format::NHWCD4);
megdnn_assert(src.ndim == 5,
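
Note: for NCHW64 the deduction added above follows the other NCHWx formats: dst[0] = N, dst[1] = OC/64, dst[4] = 64, and the spatial sizes follow the usual convolution rule o = (i + 2*pad - dilated_kernel) / stride + 1. A standalone sketch (the helper below re-implements that rule for illustration; it is not megdnn's infer_conv_shape):

#include <cassert>
#include <cstddef>
#include <cstdio>

// Illustrative spatial-size rule as used by deduce_layout_fwd.
static size_t out_size(size_t in, size_t dilated_flt, size_t stride, size_t pad) {
    assert(in + 2 * pad >= dilated_flt);
    return (in + 2 * pad - dilated_flt) / stride + 1;
}

int main() {
    // hypothetical dense NCHW64 case: src (1, 128/64, 56, 56, 64), 3x3 filter, OC = 128
    size_t n = 1, ic = 128, hi = 56, wi = 56;
    size_t oc = 128, fh = 3, fw = 3, sh = 1, sw = 1, ph = 1, pw = 1;
    assert(ic % 64 == 0 && oc % 64 == 0);

    size_t dst[5];
    dst[0] = n;
    dst[1] = oc / 64;                   // output channel blocks
    dst[2] = out_size(hi, fh, sh, ph);  // dilation == 1, so the dilated size equals fh
    dst[3] = out_size(wi, fw, sw, pw);
    dst[4] = 64;                        // interleaved channels
    std::printf("dst = (%zu, %zu, %zu, %zu, %zu)\n",
                dst[0], dst[1], dst[2], dst[3], dst[4]);
    return 0;
}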


+11 -0  dnn/src/common/utils.cpp

@@ -245,6 +245,17 @@ float megdnn::mul_scale(DType lhs, DType rhs) {
}
// clang-format on

float megdnn::get_scale(DType dt) {
megdnn_assert(dt.category() == DTypeCategory::QUANTIZED);
#define cb(_dt) \
if (dt.enumv() == DTypeTrait<_dt>::enumv) \
return dt.param<_dt>().scale;
MEGDNN_FOREACH_QUANTIZED_DTYPE(cb)
MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb)
#undef cb
megdnn_assert_internal(0);
}

bool megdnn::dtype_almost_equal(DType lhs, DType rhs) {
if (lhs.enumv() != rhs.enumv())
return false;


+2 -0  dnn/src/common/utils.h

@@ -504,6 +504,8 @@ bool vec_contains(const SmallVector<T>& vec, const T& elem) {

float mul_scale(DType lhs, DType rhs);

float get_scale(DType dt);

template <typename stype, typename dtype>
dtype convert(stype src, dtype dst, size_t offset);



+0 -1  dnn/src/cuda/conv_bias/algo.h

@@ -807,7 +807,6 @@ public:
AlgoBatchedMatmul batched_matmul;
std::vector<AlgoInt8NCHW4DotProdImplicitGemm> int8_nchw4_dotprod;
AlgoInt8CHWN4DotProdImplicitGemm int8_chwn4_dotprod;
- <<<<<<< HEAD
#if CUDA_VERSION >= 10000
AlgoQUInt4x4x32WMMA wmma_quint4x4x32;
std::vector<AlgoInt8CHWN4IMMAImplicitGemm> int8_chwn4_imma;


+8 -0  dnn/src/cuda/conv_bias/conv_bias_int8.cuh

@@ -150,4 +150,12 @@ void do_conv_bias_int8_implicit_gemm_imma8x32x16_cdiv4hwn4_unroll_width(
UNPACK_CONV_PARAMETER(_filter_meta, _param); \
MARK_USED_VAR

#define UNPACK_CONV_BIAS_NCHW64_PARAM(_src, _filter_meta, _dst, _param) \
using Format = param::ConvBias::Format; \
megdnn_assert(_param.format == Format::NCHW64); \
size_t n = (_src)[0], ci = (_src)[1] * 64, hi = (_src)[2], wi = (_src)[3]; \
size_t co = (_dst)[1] * 64, ho = (_dst)[2], wo = (_dst)[3]; \
UNPACK_CONV_PARAMETER(_filter_meta, _param); \
MARK_USED_VAR

// vim: syntax=cuda.doxygen

+302 -0  dnn/src/cuda/conv_bias/sass_implicit_gemm_int4_nchw64_imma.cpp

@@ -0,0 +1,302 @@
/**
* \file dnn/src/cuda/conv_bias/sass_implicit_gemm_int4_nchw64_imma.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#include "./algo.h"
#include "src/cuda/conv_bias/sass_helper.cuh"
#include "src/cuda/sass_loader.h"
#include "src/cuda/utils.h"
#include "src/common/conv_bias.h"

using namespace megdnn;
using namespace cuda;
using namespace sass;

namespace {
#if !MEGDNN_TEGRA_X1
// all strides are in bytes
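// Each table entry occupies 8 bytes: an int32 byte offset into the input
// plus four packed int8 values (kh, kw, -kh, -kw), presumably used by the
// SASS kernel for boundary predicates. The first inc_step entries hold the
// absolute offsets of the first inc_step (c, kh, kw) positions along the
// reduction; the remaining fh * fw * inc_step entries hold the delta from
// position (i - inc_step) to position i, so the kernel can advance by one
// table entry per K step of the main loop.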
void compute_conv2d_offset(size_t fh, size_t fw, size_t ics, size_t ihs,
Conv2dConstantOffset& constant_offset) {
constexpr int interleaved = 64;
constexpr int size_bits = 4;
constexpr int threadblock_k = 128;
constexpr int inc_step = threadblock_k / interleaved;
size_t i = 0;
int* s32 = reinterpret_cast<int*>(&(constant_offset.c_offset[0]));
for (; i < inc_step; i++) {
int c = i / (fh * fw);
int khkw = i % (fh * fw);
int kh = khkw / fw;
int kw = khkw % fw;
s32[2 * i] = c * ics + kh * ihs + kw * interleaved * size_bits / 8;
int8_t* s8 = reinterpret_cast<int8_t*>(&(s32[2 * i + 1]));
s8[0] = kh;
s8[1] = kw;
s8[2] = -kh;
s8[3] = -kw;
}
for (; i < (inc_step + fh * fw * inc_step); i++) {
int c = i / (fh * fw);
int khkw = i % (fh * fw);
int kh = khkw / fw;
int kw = khkw % fw;
s32[2 * i] = c * ics + kh * ihs + kw * interleaved * size_bits / 8;
int8_t* s8 = reinterpret_cast<int8_t*>(&(s32[2 * i + 1]));
s8[0] = kh;
s8[1] = kw;
s8[2] = -kh;
s8[3] = -kw;
int i_ = i - inc_step;
c = i_ / (fh * fw);
khkw = i_ % (fh * fw);
kh = khkw / fw;
kw = khkw % fw;
s32[2 * i] -= c * ics + kh * ihs + kw * interleaved * size_bits / 8;
}
}
#endif
}; // namespace

std::string ConvBiasForwardImpl::AlgoSASSInt4NCHW64IMMAImplicitGemm::kernel_key(
const SizeArgs& args) const {
std::string kernel_key;
using NonlineMode = Param::NonlineMode;
auto&& param = args.opr->param();
if (args.z_layout->ndim > 0) {
kernel_key =
ssprintf("%s_conv_bias_int4_fuse_z_imma_ldg16_%ux%u",
current_device_arch_name(), m_tile_nhw, m_tile_oc);
} else {
kernel_key =
ssprintf("%s_conv_bias_int4_imma_ldg16_%ux%u",
current_device_arch_name(), m_tile_nhw, m_tile_oc);
}
if (param.nonlineMode == NonlineMode::H_SWISH) {
kernel_key += "_hswish";
} else {
megdnn_assert(param.nonlineMode == NonlineMode::RELU ||
param.nonlineMode == NonlineMode::IDENTITY);
kernel_key += "_relu";
}
return kernel_key;
}

bool ConvBiasForwardImpl::AlgoSASSInt4NCHW64IMMAImplicitGemm::is_available(
const SizeArgs& args) const {
if (args.bias_layout->ndim <= 0)
return false;
using Param = param::ConvBias;
using Format = Param::Format;
using Sparse = Param::Sparse;
using Mode = Param::Mode;
bool available = true;
auto&& param = args.opr->param();
auto&& fm = args.filter_meta;
if (!check_bias_share_in_channel(*(args.bias_layout), param.format))
return false;
if (param.format != Format::NCHW64)
return false;
UNPACK_CONV_BIAS_NCHW64_PARAM(*(args.src_layout), fm, *(args.dst_layout),
param);
// TODO support group conv
available &= param.sparse == Sparse::DENSE;
// mode must be cross correlation
available &= param.mode == Mode::CROSS_CORRELATION;
// check data type
auto src_dtype = args.src_layout->dtype,
filter_dtype = args.filter_layout->dtype,
bias_dtype = args.bias_layout->dtype,
dst_dtype = args.dst_layout->dtype;
available &= (src_dtype.enumv() == DTypeEnum::QuantizedS4 &&
filter_dtype.enumv() == DTypeEnum::QuantizedS4 &&
bias_dtype.enumv() == DTypeEnum::QuantizedS32 &&
dst_dtype.enumv() == DTypeEnum::QuantizedS4);
// TODO: support dilation
available &= dh == 1 && dw == 1;
// ensure precomputed offsets are positive integers
available &= hi >= fh && wi >= fw;
// only support sm_75 or later, platform should have tensorcore int8
// support
available &= is_compute_capability_required(7, 5);
// param buffer size is 4K, use 3K to store precomputed offset, fh * fw <=
// (3*1024/4/2/2) - 1
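// (the bound follows from the offset-table layout: each of the 1 + fh * fw
// groups takes inc_step * 8 = 16 bytes, and 16 * (1 + fh * fw) <= 3 * 1024
// gives fh * fw <= 191)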
available &= fh * fw <= 191;
return available;
}

size_t
ConvBiasForwardImpl::AlgoSASSInt4NCHW64IMMAImplicitGemm::get_workspace_in_bytes(
const SizeArgs& args) const {
if (args.preprocessed_filter == nullptr) {
return args.filter_layout->span().dist_byte() +
args.bias_layout->span().dist_byte();
}
return 0_z;
}

void ConvBiasForwardImpl::AlgoSASSInt4NCHW64IMMAImplicitGemm::exec(
const ExecArgs& args) const {
#if MEGDNN_TEGRA_X1
megdnn_throw("sass kernel is disabled at compile time for TX1");
#else
using Format = Param::Format;
auto&& param = args.opr->param();
auto&& fm = args.filter_meta;
UNPACK_CONV_BIAS_NCHW64_PARAM(*(args.src_layout), fm, *(args.dst_layout),
param);
auto&& stream = cuda_stream(args.opr->handle());
constexpr int interleaved = 64;

void* bias_ptr = nullptr;
void* filter_ptr = nullptr;
if (args.preprocessed_filter) {
megdnn_assert(args.preprocessed_filter->tensors.size() == 2);
filter_ptr = args.preprocessed_filter->tensors[0].raw_ptr;
bias_ptr = args.preprocessed_filter->tensors[1].raw_ptr;
} else {
// reorder filter and bias
filter_ptr = reinterpret_cast<void*>(args.workspace.raw_ptr);
bias_ptr =
reinterpret_cast<void*>(args.workspace.raw_ptr +
args.filter_layout->span().dist_byte());
reorder_imma_filter_bias<4, 64>(
reinterpret_cast<int8_t*>(filter_ptr),
reinterpret_cast<int32_t*>(bias_ptr),
args.filter_tensor->compatible_ptr<int8_t>(),
args.bias_tensor->compatible_ptr<int32_t>(), co, ci, fh, fw,
stream);
}

uint32_t u32_n = n, u32_ci = ci, u32_hi = hi, u32_wi = wi, u32_fh = fh,
u32_fw = fw, u32_sh = sh, u32_sw = sw, u32_ph = ph, u32_pw = pw,
u32_co = co, u32_ho = ho, u32_wo = wo;
Conv2dInt4Param kern_param(u32_n, u32_ci, u32_hi, u32_wi, u32_fh, u32_fw,
u32_sh, u32_sw, u32_ph, u32_pw, u32_co, u32_ho,
u32_wo, interleaved);

Conv2dConstantOffset kern_coffset;
compute_conv2d_offset(fh, fw, kern_param.ics, kern_param.ihs, kern_coffset);
// The starting address of Turing param buffer is c[0x0][0x160]
kern_coffset.c_offset_param.begin = param_buffer_start_address();
kern_coffset.c_offset_param.size = 16 * (1 + fh * fw);
kern_coffset.c_offset_param.max = 16 * fh * fw;
kern_coffset.c_offset_param.rewind = 16 * (1 - fh * fw);
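// size: the table holds 1 + fh * fw groups of 16 bytes each; max and rewind
// are made relative to `begin` once it is finalised below -- max marks the
// last 16-byte group and rewind jumps back by fh * fw - 1 groups, presumably
// letting the kernel wrap to the start of the delta region after a full
// filter sweep.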

auto kern_key = kernel_key(args);
float src_scale = args.src_layout->dtype.param<dtype::QuantizedS4>().scale,
filter_scale =
args.filter_layout->dtype.param<dtype::QuantizedS4>().scale,
bias_scale =
args.bias_layout->dtype.param<dtype::QuantizedS32>().scale,
dst_scale = args.dst_layout->dtype.param<dtype::QuantizedS4>().scale;
float alpha = src_scale * filter_scale / dst_scale,
beta = bias_scale / dst_scale;
float inv_dst_scale = 1.f / dst_scale;

unsigned int tx = m_threads, ty = 1;
unsigned int gridx = div_ceil<unsigned int>(
static_cast<unsigned int>(n * ho * wo), m_tile_nhw);
unsigned int gridy =
div_ceil<unsigned int>(static_cast<unsigned int>(co), m_tile_oc);
void* src_ptr = const_cast<void*>(args.src_tensor->raw_ptr);
void* dst_ptr = const_cast<void*>(args.dst_tensor->raw_ptr);

using NonlineMode = Param::NonlineMode;
auto&& kernel = SASSKernelLoader::instance().get_kernel(kern_key, kern_key);
if (args.z_layout->ndim > 0) {
void* z_ptr = const_cast<void*>(args.z_tensor->raw_ptr);
float z_scale = args.z_layout->dtype.param<dtype::QuantizedS4>().scale;
float gamma = z_scale / dst_scale;
std::vector<void*> params = {&src_ptr, &filter_ptr, &bias_ptr, &z_ptr,
&dst_ptr, &alpha, &beta, &gamma};
kern_coffset.c_offset_param.begin +=
sizeof(src_ptr) + sizeof(filter_ptr) + sizeof(bias_ptr) +
sizeof(z_ptr) + sizeof(dst_ptr) + sizeof(alpha) + sizeof(beta) +
sizeof(gamma);

uint32_t relu = param.nonlineMode == NonlineMode::RELU ? 1 : 0;
if (param.nonlineMode == NonlineMode::H_SWISH) {
params.push_back(&dst_scale);
params.push_back(&inv_dst_scale);
kern_coffset.c_offset_param.begin +=
sizeof(dst_scale) + sizeof(inv_dst_scale);
} else {
params.push_back(&relu);
kern_coffset.c_offset_param.begin += sizeof(relu);
}
params.push_back(&kern_param);
kern_coffset.c_offset_param.begin += sizeof(kern_param);
kern_coffset.c_offset_param.begin +=
sizeof(kern_coffset.c_offset_param);
kern_coffset.c_offset_param.max += kern_coffset.c_offset_param.begin;
params.push_back(&kern_coffset);
cucheck(cuLaunchKernel(kernel, gridx, gridy, 1, tx, ty, 1, 0, stream,
params.data(), 0));
} else {
std::vector<void*> params = {&src_ptr, &filter_ptr, &bias_ptr,
&dst_ptr, &alpha, &beta};

kern_coffset.c_offset_param.begin +=
sizeof(src_ptr) + sizeof(filter_ptr) + sizeof(bias_ptr) +
sizeof(dst_ptr) + sizeof(alpha) + sizeof(beta);

uint32_t relu = param.nonlineMode == NonlineMode::RELU ? 1 : 0;
if (param.nonlineMode == NonlineMode::H_SWISH) {
params.push_back(&dst_scale);
params.push_back(&inv_dst_scale);
kern_coffset.c_offset_param.begin +=
sizeof(dst_scale) + sizeof(inv_dst_scale);
} else {
params.push_back(&relu);
kern_coffset.c_offset_param.begin += sizeof(relu);
}
params.push_back(&kern_param);
kern_coffset.c_offset_param.begin += sizeof(kern_param);
kern_coffset.c_offset_param.begin +=
sizeof(kern_coffset.c_offset_param);
kern_coffset.c_offset_param.max += kern_coffset.c_offset_param.begin;
params.push_back(&kern_coffset);
cucheck(cuLaunchKernel(kernel, gridx, gridy, 1, tx, ty, 1, 0, stream,
params.data(), 0));
}
after_kernel_launch();
#endif
}

size_t ConvBiasForwardImpl::AlgoSASSInt4NCHW64IMMAImplicitGemm::
get_preprocess_workspace_in_bytes(const SizeArgs& args) const {
return 0_z;
}

SmallVector<TensorLayout> ConvBiasForwardImpl::
AlgoSASSInt4NCHW64IMMAImplicitGemm::deduce_preprocessed_filter_layout(
const SizeArgs& args) const {
return {args.filter_layout->collapse_contiguous(),
args.bias_layout->collapse_contiguous()};
}

void ConvBiasForwardImpl::AlgoSASSInt4NCHW64IMMAImplicitGemm::exec_preprocess(
const ExecArgs& args) const {
using Format = Param::Format;
auto&& param = args.opr->param();
auto&& fm = args.filter_meta;
UNPACK_CONV_BIAS_NCHW64_PARAM(*(args.src_layout), fm, *(args.dst_layout),
param);
auto&& stream = cuda_stream(args.opr->handle());
reorder_imma_filter_bias<4, 64>(
args.preprocessed_filter->tensors[0].compatible_ptr<int8_t>(),
args.preprocessed_filter->tensors[1].compatible_ptr<int32_t>(),
args.filter_tensor->compatible_ptr<int8_t>(),
args.bias_tensor->compatible_ptr<int32_t>(), co, ci, fh, fw,
stream);
}

// vim: syntax=cpp.doxygen
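
Note: when no preprocessed filter is supplied, get_workspace_in_bytes and exec above place the reordered filter at the start of the workspace and the reordered bias immediately after it. A standalone sketch of that layout under this file's int4/NCHW64 assumptions (illustrative sizes only, not the megdnn layout code):

#include <cstddef>
#include <cstdint>
#include <cstdio>

// Workspace layout used when args.preprocessed_filter == nullptr:
//   [0, filter_bytes)                            reordered int4 filter
//   [filter_bytes, filter_bytes + bias_bytes)    reordered int32 bias
int main() {
    // illustrative dense NCHW64 filter: OC x IC x FH x FW 4-bit weights
    size_t oc = 128, ic = 128, fh = 3, fw = 3;
    size_t filter_bytes = oc * ic * fh * fw / 2;  // two int4 values per byte
    size_t bias_bytes = oc * sizeof(int32_t);     // one QuantizedS32 bias per output channel

    size_t workspace = filter_bytes + bias_bytes; // value returned by get_workspace_in_bytes
    size_t filter_offset = 0;
    size_t bias_offset = filter_bytes;            // bias_ptr = workspace.raw_ptr + filter span

    std::printf("workspace %zu bytes (filter @%zu, bias @%zu)\n",
                workspace, filter_offset, bias_offset);
    return 0;
}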
