GitOrigin-RevId: 75d579635a
release-1.5
@@ -161,7 +161,7 @@ WorkspaceBundle ConvBiasForwardImpl::AlgoFallbackNCHWQS4::get_workspace_bundle(
                                 ws_size_underlying_algo, ws_size_z}};
     }
     return WorkspaceBundle{raw_ptr,
-                           {ws_size_src, ws_size_filter,
-                            ws_size_underlying_algo, ws_size_dst}};
+                           {ws_size_src, ws_size_filter, ws_size_dst,
+                            ws_size_underlying_algo}};
 }
 // vim: syntax=cpp.doxygen
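
Note: the reordering above is not cosmetic. WorkspaceBundle hands out sub-buffers by position, so the z and non-z branches must declare their common slots (src, filter, dst) at the same indices for callers that fetch a buffer by index. A minimal sketch of that positional contract (a simplified stand-in, not megdnn's real class, which also handles alignment):

    #include <cstddef>
    #include <vector>

    // Simplified stand-in for megdnn's WorkspaceBundle: sub-buffers are
    // carved out of one raw allocation in construction order.
    class Bundle {
        void* m_raw;
        std::vector<size_t> m_sizes;
    public:
        Bundle(void* raw, std::vector<size_t> sizes)
                : m_raw{raw}, m_sizes{std::move(sizes)} {}
        void* get(size_t idx) const {
            char* p = static_cast<char*>(m_raw);
            for (size_t i = 0; i < idx; ++i)
                p += m_sizes[i];  // skip every slot declared before idx
            return p;
        }
    };
    // After the fix, slot 2 is the dst buffer in both branches, so callers
    // can index the bundle uniformly regardless of whether z is present.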
@@ -30,7 +30,10 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
                     param().mode == param::RelayoutFormat::Mode::CHWN4_NCHW4 ||
                     param().mode == Param::Mode::NCHW_NCHW4_IC_SMALL ||
                     param().mode ==
-                            Param::Mode::NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT,
+                            Param::Mode::
+                                    NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT ||
+                    param().mode == Param::Mode::NCHW_NCHW64 ||
+                    param().mode == Param::Mode::NCHW64_NCHW,
             "relayout format of cuda only support NCHW4->CHWN4 or "
             "CHWN4->NCHW4 or NCHW->NCHW4");
     if ((param().mode == param::RelayoutFormat::Mode::NCHW4_CHWN4 ||
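
Note: the assert message above is now out of date with the accepted modes (it does not mention NCHW<->NCHW64). For reference, NCHW_NCHW64 reinterprets [N, C, H, W] as [N, C/64, H, W, 64] with 64 channels packed innermost, and NCHW64_NCHW is the inverse. A quick shape sketch (hypothetical helper, not part of the patch):

    #include <array>
    #include <cassert>
    #include <cstddef>

    // Hypothetical helper: dst shape for Mode::NCHW_NCHW64.
    std::array<size_t, 5> nchw_to_nchw64_shape(std::array<size_t, 4> s) {
        assert(s[1] % 64 == 0);  // assumes channels divisible by the pack size
        return {s[0], s[1] / 64, s[2], s[3], 64};
    }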
@@ -26,6 +26,9 @@ inline void get_scale_zeropoint(const DType& tensor_dtype, float& scale,
         scale = tensor_dtype.param<dtype::QuantizedS8>().scale;
     } else if (tensor_dtype.enumv() == DTypeEnum::QuantizedS4) {
         scale = tensor_dtype.param<dtype::QuantizedS4>().scale;
+    } else if (tensor_dtype.enumv() == DTypeEnum::Quantized4Asymm) {
+        zero_point = tensor_dtype.param<dtype::Quantized4Asymm>().zero_point;
+        scale = tensor_dtype.param<dtype::Quantized4Asymm>().scale;
     }
 }
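
Note: for Quantized4Asymm both parameters matter, because the dequantize rule is real = scale * (stored - zero_point), while the symmetric QuantizedS4 case needs only the scale. A minimal sketch of the convention assumed here:

    #include <cstdint>

    // Asymmetric 4-bit dequantize: stored codes live in [0, 15] and
    // zero_point marks which code represents real 0.
    inline float dequantize_q4asymm(uint8_t stored, float scale,
                                    uint8_t zero_point) {
        return scale * (float(stored) - float(zero_point));
    }

    // Symmetric 4-bit: stored codes live in [-8, 7], zero is always code 0.
    inline float dequantize_q4(int8_t stored, float scale) {
        return scale * float(stored);
    }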
@@ -41,8 +44,6 @@ void relayout_format::RelayoutFormatFast::exec(const TensorND& src,
                                                cudaStream_t stream,
                                                RelayoutFormat::Param::Mode mode,
                                                int group) {
-    auto&& stype = src.layout.dtype;
-    auto&& dtype = dst.layout.dtype;
     float src_scale = 1.f;
     float dst_scale = 1.f;
     uint8_t src_zero_point = 0;
@@ -538,9 +538,9 @@ struct Translayout<64, 8, SrcType, dtype::QuantizedS4, dtype::QuantizedS4,
 };
 #undef pack
-#define pack(_idx)                                        \
-    ((uint8_t)(post_process(intermediate[0][_idx])) |     \
-     ((uint8_t)(post_process(intermediate[1][_idx])) << 4))
+#define pack(_idx)                                        \
+    ((post_process(intermediate[0][_idx]) & 0xf) |        \
+     (post_process(intermediate[1][_idx]) << 4))
 template <typename SrcType, bool same_scale>
 struct Translayout<64, 2, SrcType, dtype::QuantizedS4, dtype::QuantizedS4,
                    same_scale> {
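
Note: post_process yields a signed value for QuantizedS4, and casting it to uint8_t keeps the sign-extended high bits, so a negative low nibble (e.g. -1 -> 0xFF) bleeds into the high nibble when OR-ed. Masking with 0xf isolates the low nibble first; the high nibble needs no mask because truncation to a byte drops the excess bits. (In the Quantized4Asymm variant below, post_process already returns unsigned codes in [0, 15], so the plain OR is safe there.) A standalone repro of the failure mode:

    #include <cassert>
    #include <cstdint>

    int main() {
        int lo = -1, hi = 2;  // two s4 results from post_process
        // Old packing: the uint8_t cast keeps the sign-extended bits of lo,
        // so 0xFF | (0x02 << 4) == 0xFF -- the high nibble is destroyed.
        uint8_t bad = (uint8_t)lo | ((uint8_t)hi << 4);
        // Fixed packing: mask the low nibble before merging.
        uint8_t good = (lo & 0xf) | (hi << 4);
        assert(bad == 0xff);
        assert(good == 0x2f);  // hi=2 in bits [4,8), lo=-1 -> 0xf in bits [0,4)
        return 0;
    }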
@@ -648,9 +648,9 @@ struct Translayout<64, 8, SrcType, dtype::Quantized4Asymm,
 };
 #undef pack
-#define pack(_idx)                                        \
-    ((uint8_t)(post_process(intermediate[0][_idx])) |     \
-     ((uint8_t)(post_process(intermediate[1][_idx])) << 4))
+#define pack(_idx)                                        \
+    (post_process(intermediate[0][_idx]) |                \
+     (post_process(intermediate[1][_idx]) << 4))
 template <typename SrcType, bool same_scale>
 struct Translayout<64, 2, SrcType, dtype::Quantized4Asymm,
                    dtype::Quantized4Asymm, same_scale> {
@@ -820,13 +820,25 @@ __global__ void kern_nchw_nchwx(
         int n_stride_src, int ic_stride, int n_stride_dst, int oc_stride,
         CudaPostProcess<DnnSrcType, DnnDstType, same_scale> post_process,
         const char zero_point, const int group, const int ocpg) {
+    static constexpr int size_src_type = sizeof(SrcType);
+    static constexpr int size_dst_type = sizeof(DstType);
+#ifndef MEGDNN_COMMA
+#define MEGDNN_COMMA ,
+#endif
+    MEGDNN_STATIC_ASSERT(std::is_same<SrcType MEGDNN_COMMA DstType>::value,
+                         "Currently this kernel only support accessing tensor "
+                         "src and dst in same data type.");
+    n_stride_src /= size_src_type;
+    ic_stride /= size_src_type;
+    n_stride_dst /= size_dst_type;
+    oc_stride /= size_dst_type;
     const int n_idx = blockIdx.y;
     const int ihw_block_idx = blockIdx.x * blockDim.x + threadIdx.x;
     const int ihw_offset = ihw_block_idx * pack_w;
     const int ihw_offset_in_type =
-            ihw_offset * size_nbits / (8 * sizeof(SrcType));
+            ihw_offset * size_nbits / (8 * size_src_type);
     if (ihw_offset < ihw) {
         const int src_offset_base = n_idx * n_stride_src + ihw_offset_in_type;
         const int dst_offset_base =
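
Note: the kernel now takes strides in bytes and normalizes them to element units once, and ihw_offset_in_type accounts for sub-byte types: with size_nbits = 4 and a 1-byte SrcType, two pixels share one byte. A worked check of the arithmetic (values chosen for illustration):

    #include <cassert>

    int main() {
        // 4-bit data stored in int8 words: size_nbits = 4, sizeof(SrcType) = 1.
        const int size_nbits = 4, size_src_type = 1;
        // Pixel offset 10 within the HW plane lands at byte 5 of the row,
        // because each byte holds 8 / 4 = 2 pixels.
        const int ihw_offset = 10;
        const int ihw_offset_in_type =
                ihw_offset * size_nbits / (8 * size_src_type);
        assert(ihw_offset_in_type == 5);
        // For an 8-bit type the formula degenerates to the identity.
        assert(10 * 8 / (8 * 1) == 10);
        return 0;
    }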
@@ -836,7 +848,7 @@ __global__ void kern_nchw_nchwx(
         const int ic_block = icpg / pack_c;
         const int remain_ic = icpg % pack_c;
         const int src_group_stride = icpg * ic_stride;
-        const int dst_group_stride = ocpg * oc_stride;
+        const int dst_group_stride = (ocpg / pack_c) * oc_stride;
         for (int g_idx = 0; g_idx < group; ++g_idx) {
             const int src_offset =
                     src_offset_base + g_idx * src_group_stride;
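
Note: the dst tensor is NCHWx, so a group owns ocpg / pack_c channel blocks and oc_stride is the span of one block; the old `ocpg * oc_stride` stepped pack_c times too far. Worked numbers:

    #include <cassert>

    int main() {
        // Hypothetical NCHW64 dst: each group owns ocpg = 128 channels,
        // i.e. 128 / 64 = 2 channel blocks, each spanning oc_stride elements.
        const int ocpg = 128, pack_c = 64, oc_stride = 1024;
        const int wrong = ocpg * oc_stride;             // 131072: pack_c x too far
        const int right = (ocpg / pack_c) * oc_stride;  // 2048: next group's base
        assert(wrong == 131072 && right == 2048);
        return 0;
    }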
@@ -1018,7 +1030,7 @@ public:
                                             int chan_stride_in_elements_,
                                             int channel_)
             : pointer{pointer_},
-              chan_stride_in_elements{chan_stride_in_elements},
+              chan_stride_in_elements{chan_stride_in_elements_},
               channel{channel_} {}
     MEGDNN_DEVICE __forceinline__ void load(Fragment& frag) {
@@ -1031,7 +1043,7 @@ public:
                 int frag_idx = i / pack_size *
                                        (lane_size_in_type / pack_size_in_type) +
                                j;
-                bool guard = i >= channel;
+                bool guard = i < channel;
                 cutlass::arch::global_load<AccessType, pack_size_in_byte>(
                         frag_ptr[frag_idx],
                         reinterpret_cast<void*>(pointer_ +
@@ -1052,7 +1064,7 @@ public:
                 int frag_idx = i / pack_size *
                                        (lane_size_in_type / pack_size_in_type) +
                                j;
-                bool guard = i >= channel;
+                bool guard = i < channel;
                 cutlass::arch::global_store<AccessType, pack_size_in_byte>(
                         frag_ptr[frag_idx],
                         reinterpret_cast<void*>(pointer_ +
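
Note on the two guard fixes above: the predicate passed to cutlass::arch::global_load / global_store enables the access when true, so it must assert in-bounds (`i < channel`); the old `i >= channel` touched exactly the rows that fall off the tensor. A minimal sketch of that predicated-access contract, using a plain CUDA stand-in rather than the cutlass primitive:

    #include <cstdint>

    // Stand-in with the same convention as cutlass::arch::global_load:
    // the access happens only when guard is true.
    template <typename AccessType>
    __device__ __forceinline__ void guarded_load(AccessType& dst,
                                                 const void* ptr, bool guard) {
        if (guard)  // guard == "this access is in bounds"
            dst = *reinterpret_cast<const AccessType*>(ptr);
    }

    __global__ void copy_rows(const int4* src, int4* dst, int channel) {
        int i = blockIdx.x * blockDim.x + threadIdx.x;
        int4 frag = {0, 0, 0, 0};
        bool guard = i < channel;  // correct polarity: true while in range
        guarded_load(frag, src + i, guard);
        if (guard)
            dst[i] = frag;
    }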
@@ -1092,11 +1104,24 @@ __global__ void kern_nchwx_nchw(
                                 size_nbits>;
     using Transpose = Translayout<pack_c, pack_w, SrcType, DnnSrcType,
                                   DnnDstType, same_scale>;
+    static constexpr int size_src_type = sizeof(SrcType);
+    static constexpr int size_dst_type = sizeof(DstType);
+    MEGDNN_STATIC_ASSERT(std::is_same<SrcType MEGDNN_COMMA DstType>::value,
+                         "Currently this kernel only support accessing tensor "
+                         "src and dst in same data type.");
+    n_stride_src /= size_src_type;
+    ic_stride /= size_src_type;
+    n_stride_dst /= size_dst_type;
+    oc_stride /= size_dst_type;
+#undef MEGDNN_COMMA
     const int n_idx = blockIdx.y;
     const int ihw_block_idx = blockIdx.x * blockDim.x + threadIdx.x;
     const int ihw_offset = ihw_block_idx * pack_w;
     const int ihw_offset_in_type =
-            ihw_offset * size_nbits / (8 * sizeof(SrcType));
+            ihw_offset * size_nbits / (8 * size_src_type);
+    const int oc_stride_inner_dtype =
+            oc_stride * size_dst_type / sizeof(InnerDtype);
     if (ihw_offset < ihw) {
         const int ic_block = (ic + pack_c - 1) / pack_c;
         const int src_offset_base =
@@ -1105,8 +1130,8 @@ __global__ void kern_nchwx_nchw(
         SrcIterator src_iterator{const_cast<SrcType*>(src + src_offset_base),
                                  ic_stride, ic};
         DstIteraotr dst_iterator{
-                reinterpret_cast<InnerDtype*>(dst + dst_offset_base), oc_stride,
-                ic};
+                reinterpret_cast<InnerDtype*>(dst + dst_offset_base),
+                oc_stride_inner_dtype, ic};
         for (int ic_blk_idx = 0; ic_blk_idx < ic_block; ++ic_blk_idx) {
             typename SrcIterator::Fragment src_frag;
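
Note: the dst iterator advances an InnerDtype* pointer, so its stride must be counted in InnerDtype units; `oc_stride` (already normalized to dst-type units) is converted back to bytes and divided by sizeof(InnerDtype). Worked numbers under assumed types:

    #include <cassert>

    int main() {
        // Assume 4-bit dst stored in int8 words (size_dst_type = 1) while
        // the iterator writes through int32 words (InnerDtype = int, 4 bytes).
        const int size_dst_type = 1, size_inner_dtype = 4;
        const int oc_stride = 4096;  // stride in dst-type units (== bytes here)
        const int oc_stride_inner_dtype =
                oc_stride * size_dst_type / size_inner_dtype;
        assert(oc_stride_inner_dtype == 1024);  // same span, in int words
        return 0;
    }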
@@ -1143,12 +1168,13 @@ void relayout_format::relayout_format_cuda_nchw_nchwx(
     DEF(64, Quantized4Asymm, Quantized4Asymm)
     DEF(4, QuantizedS8, QuantizedS8)
     DEF(4, Uint8, QuantizedS8)
-    DEF(4, Quantized8Asymm, Quantized8Asymm)
-    DEF(4, QuantizedS32, QuantizedS32);
+    DEF(4, Quantized8Asymm, QuantizedS8)
+    DEF(4, QuantizedS32, QuantizedS32)
     // clang-format on
     megdnn_assert(pack_oc == 4 || pack_oc == 64,
-                  "Unsupport pack size(pack_oc:%d)", pack_oc);
-#undef DEF
+                  "Unsupport pack size(pack_oc:%d, src:%s, dst:%s)", pack_oc,
+                  stype.name(), dtype.name());
+#undef DEF
     const int in_n = src.layout[0];
     const int out_n = dst.layout[0];
     const int ic = src.layout[1];
@@ -1157,6 +1183,7 @@ void relayout_format::relayout_format_cuda_nchw_nchwx(
     const int oc = dst.layout[1] * pack_oc;
     const int hw = h * w;
     const int ocpg = oc / group;
+    // stride in byte
     const int n_stride_src = src_layout.dtype.size(src_layout.stride[0]);
     const int ic_stride = src_layout.dtype.size(src_layout.stride[1]);
     const int n_stride_dst = dst_layout.dtype.size(dst_layout.stride[0]);
@@ -1244,20 +1271,20 @@ void relayout_format::relayout_format_cuda_nchwx_nchw(
     auto& src_layout = src.layout;
     auto& dst_layout = dst.layout;
     // check pack size
-    int pack_oc = std::numeric_limits<int>::min();
-#define DEF(_pack_oc, _src_type, _dst_type)             \
+    int pack_ic = std::numeric_limits<int>::min();
+#define DEF(_pack_ic, _src_type, _dst_type)             \
     if (stype.enumv().ev == DTypeEnum::Ev::_src_type && \
         dtype.enumv().ev == DTypeEnum::Ev::_dst_type) { \
-        pack_oc = _pack_oc;                             \
+        pack_ic = _pack_ic;                             \
     }
     // clang-format off
     DEF(64, QuantizedS4, QuantizedS4)
     DEF(64, Quantized4Asymm, Quantized4Asymm)
     // clang-format on
-    megdnn_assert(pack_oc == 64, "Unsupport pack size(pack_oc:%d)", pack_oc);
+    megdnn_assert(pack_ic == 64, "Unsupport pack size(pack_ic:%d)", pack_ic);
 #undef DEF
     const int n = src.layout[0];
-    const int c = src.layout[1];
+    const int c = src.layout[1] * pack_ic;
     const int h = src.layout[2];
     // align to byte
     const int w = src.layout[3];
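
Note: in the NCHWx->NCHW direction the pack size describes input channels, hence the pack_ic rename, and the logical channel count has to be recovered from the packed layout: src is [N, C/64, H, W, 64], so c = src.layout[1] * pack_ic. Quick check:

    #include <cassert>

    int main() {
        // NCHW64 src shape {3, 2, 14, 16, 64}: layout[1] counts 64-channel
        // blocks, so the logical channel count is 2 * 64 = 128.
        const int pack_ic = 64;
        const int layout1 = 2;
        assert(layout1 * pack_ic == 128);
        return 0;
    }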
@@ -1266,7 +1293,7 @@ void relayout_format::relayout_format_cuda_nchwx_nchw(
     const int ic_stride = src_layout.dtype.size(src_layout.stride[1]);
     const int n_stride_dst = dst_layout.dtype.size(dst_layout.stride[0]);
     const int oc_stride = dst_layout.dtype.size(dst_layout.stride[1]);
     bool same_scale = src_scale == dst_scale;
 #define DISPATCH_RAW(_same_scale, _pack_w, _pack_oc, _src_type, _dst_type, \
                      _src_c_type, _dst_c_type, _size_nbits)                \
@@ -378,7 +378,9 @@ MEGDNN_DEVICE __forceinline__ static float4 operator+(float4 lval,
 MEGDNN_DEVICE __forceinline__ static int transform_int8_to_int4x8(
         int s0, int s1, int s2, int s3, int s4, int s5, int s6, int s7) {
     unsigned out;
-#if __CUDA_ARCH__ >= 750
+#if __CUDA_ARCH__ >= 750 &&                     \
+        ((__CUDACC_VER_MAJOR__ > 10) ||         \
+         ((__CUDACC_VER_MAJOR__ >= 10) && (__CUDACC_VER_MINOR__ >= 2)))
     asm volatile(
             "{ .reg .u32 r4;"
             "cvt.pack.sat.s4.s32.b32 r4, %8, %7, 0;"
@@ -411,7 +413,9 @@ MEGDNN_DEVICE __forceinline__ static int transform_int8_to_int4x8(
 MEGDNN_DEVICE __forceinline__ static int transform_int8_to_uint4x8(
         int s0, int s1, int s2, int s3, int s4, int s5, int s6, int s7) {
     unsigned out;
-#if __CUDA_ARCH__ >= 750
+#if __CUDA_ARCH__ >= 750 &&                     \
+        ((__CUDACC_VER_MAJOR__ > 10) ||         \
+         ((__CUDACC_VER_MAJOR__ >= 10) && (__CUDACC_VER_MINOR__ >= 2)))
     asm volatile(
             "{ .reg .u32 r4;"
             "cvt.pack.sat.u4.s32.b32 r4, %8, %7, 0;"
@@ -226,6 +226,7 @@ void do_copy_diff_q8_q8(const TensorND& dst, const TensorND& src) {
         ++isrc;
     }
 }
+
 void do_copy_diff_q32_q32(const TensorND& dst, const TensorND& src) {
     auto isrc = tensor_iter_valonly<DTypeTrait<dtype::QuantizedS32>::ctype>(src)
                         .begin();
@@ -253,6 +254,38 @@ void do_copy_diff_u8_q8(const TensorND& dst, const TensorND& src) {
     }
 }
+
+void do_copy_diff_q4_q4(const TensorND& dst, const TensorND& src) {
+    auto isrc =
+            tensor_iter_valonly<DTypeTrait<dtype::QuantizedS4>::ctype>(src)
+                    .begin();
+    auto idst =
+            tensor_iter_valonly<DTypeTrait<dtype::QuantizedS4>::ctype>(dst)
+                    .begin();
+    auto src_dt_parm = src.layout.dtype.param<dtype::QuantizedS4>();
+    auto dst_dt_parm = dst.layout.dtype.param<dtype::QuantizedS4>();
+    for (size_t i = 0, it = dst.layout.total_nr_elems(); i < it; ++i) {
+        *idst = dst_dt_parm.quantize(src_dt_parm.dequantize(int8_t(*isrc)));
+        ++idst;
+        ++isrc;
+    }
+}
+
+void do_copy_diff_qu4_qu4(const TensorND& dst, const TensorND& src) {
+    auto isrc =
+            tensor_iter_valonly<DTypeTrait<dtype::Quantized4Asymm>::ctype>(src)
+                    .begin();
+    auto idst =
+            tensor_iter_valonly<DTypeTrait<dtype::Quantized4Asymm>::ctype>(dst)
+                    .begin();
+    auto src_dt_parm = src.layout.dtype.param<dtype::Quantized4Asymm>();
+    auto dst_dt_parm = dst.layout.dtype.param<dtype::Quantized4Asymm>();
+    for (size_t i = 0, it = dst.layout.total_nr_elems(); i < it; ++i) {
+        *idst = dst_dt_parm.quantize(src_dt_parm.dequantize(uint8_t(*isrc)));
+        ++idst;
+        ++isrc;
+    }
+}
 
 void check_layout_and_canonize(TensorLayout& src, TensorLayout& dst) {
     megdnn_assert(dst.is_non_overlapping_strong());
     src = src.collapse_contiguous();
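
Note: these copy helpers requantize element-wise, dequantizing with the source params and quantizing with the destination's; for QuantizedS4 that is clamp(round(x / scale), -8, 7). A scalar model of the per-element transform (assuming round-half-away-from-zero, as std::round does):

    #include <algorithm>
    #include <cmath>
    #include <cstdint>

    // Scalar model of do_copy_diff_q4_q4's per-element transform:
    // dst_q = clamp(round(src_q * src_scale / dst_scale), -8, 7).
    inline int8_t requantize_s4(int8_t src_q, float src_scale,
                                float dst_scale) {
        float real = src_scale * float(src_q);            // dequantize
        float q = std::round(real / dst_scale);           // quantize
        return int8_t(std::min(7.f, std::max(-8.f, q)));  // saturate to s4
    }
    // E.g. src_q = 5 at scale 1.2 is real 6.0; at dst scale 1.0 it stores as 6.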
@@ -595,6 +628,24 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
         };
         MEGDNN_DISPATCH_CPU_KERN_OPR(func(dst0, src0));
         return;
+    } else if (src.layout.dtype.enumv() == DTypeEnum::QuantizedS4 &&
+               dst.layout.dtype.enumv() == DTypeEnum::QuantizedS4) {
+        TensorND src0 = exec_src_nd, dst0 = exec_dst_nd;
+        check_layout_and_canonize(src0.layout, src0.layout);
+        auto func = [](const TensorND& dst, const TensorND& src) {
+            do_copy_diff_q4_q4(dst, src);
+        };
+        MEGDNN_DISPATCH_CPU_KERN_OPR(func(dst0, src0));
+        return;
+    } else if (src.layout.dtype.enumv() == DTypeEnum::Quantized4Asymm &&
+               dst.layout.dtype.enumv() == DTypeEnum::Quantized4Asymm) {
+        TensorND src0 = exec_src_nd, dst0 = exec_dst_nd;
+        check_layout_and_canonize(src0.layout, src0.layout);
+        auto func = [](const TensorND& dst, const TensorND& src) {
+            do_copy_diff_qu4_qu4(dst, src);
+        };
+        MEGDNN_DISPATCH_CPU_KERN_OPR(func(dst0, src0));
+        return;
     } else {
         m_handle->relayout_opr()->exec(exec_src_nd, exec_dst_nd, handle());
     }
@@ -237,6 +237,89 @@ TEST_F(CUDA, RELAYOUT_FORMAT_NCHW_NCHW4_IC_SMALL) {
             .execs({{8, 3, 768, 1280}, {}});
 }
+
+TEST_F(CUDA, RELAYOUT_FORMAT_NCHW_NCHW64) {
+    Checker<RelayoutFormat> checker(handle_cuda());
+    UniformIntRNG s4{-8, 7};
+    UniformIntRNG u4{0, 15};
+    param::RelayoutFormat param;
+    param.mode = param::RelayoutFormat::Mode::NCHW_NCHW64;
+    for (size_t n : {1, 3}) {
+        for (size_t c : {64, 128}) {
+            for (size_t h : {7, 14, 16, 28}) {
+                for (size_t w : {2, 4, 14, 16}) {
+                    checker.set_dtype(0, dtype::QuantizedS4{2.f})
+                            .set_dtype(1, dtype::QuantizedS4{2.f})
+                            .set_rng(0, &s4)
+                            .set_param(param)
+                            .execs({{n, c, h, w}, {}});
+                    checker.set_dtype(0, dtype::Quantized4Asymm{1.2f, 8})
+                            .set_dtype(1, dtype::Quantized4Asymm{1.2f, 4})
+                            .set_rng(0, &u4)
+                            .set_param(param)
+                            .execs({{n, c, h, w}, {}});
+                    checker.set_dtype(0, dtype::QuantizedS4{1.19990307f})
+                            .set_dtype(1, dtype::QuantizedS4{1.f})
+                            .set_rng(0, &s4)
+                            .set_param(param)
+                            .execs({{n, c, h, w}, {}});
+                    checker.set_dtype(0, dtype::Quantized4Asymm{1.19990307f, 8})
+                            .set_dtype(1, dtype::Quantized4Asymm{1.f, 4})
+                            .set_rng(0, &u4)
+                            .set_param(param)
+                            .set_epsilon(1e-3)
+                            .execs({{n, c, h, w}, {}});
+                }
+            }
+        }
+    }
+}
+
+TEST_F(CUDA, RELAYOUT_FORMAT_NCHW64_NCHW) {
+    Checker<RelayoutFormat> checker(handle_cuda());
+    UniformIntRNG s4{-8, 7};
+    UniformIntRNG u4{0, 15};
+    param::RelayoutFormat param;
+    param.mode = param::RelayoutFormat::Mode::NCHW64_NCHW;
+    for (size_t n : {1, 3}) {
+        for (size_t c : {64, 128}) {
+            for (size_t h : {7, 14, 16, 28}) {
+                for (size_t w : {2, 4, 14, 16}) {
+                    checker.set_dtype(0, dtype::QuantizedS4{2.f})
+                            .set_dtype(1, dtype::QuantizedS4{2.f})
+                            .set_rng(0, &s4)
+                            .set_param(param)
+                            .set_epsilon(1e-3)
+                            .execs({{n, c / 64, h, w, 64}, {}});
+                    checker.set_dtype(0, dtype::Quantized4Asymm{1.2f, 4})
+                            .set_dtype(1, dtype::Quantized4Asymm{1.2f, 8})
+                            .set_rng(0, &u4)
+                            .set_param(param)
+                            .set_epsilon(1e-3)
+                            .execs({{n, c / 64, h, w, 64}, {}});
+                    checker.set_dtype(0, dtype::QuantizedS4{1.19990307f})
+                            .set_dtype(1, dtype::QuantizedS4{1.f})
+                            .set_rng(0, &s4)
+                            .set_param(param)
+                            .set_epsilon(1e-3)
+                            .execs({{n, c / 64, h, w, 64}, {}});
+                    checker.set_dtype(0, dtype::Quantized4Asymm{1.20211209f, 8})
+                            .set_dtype(1, dtype::Quantized4Asymm{1.f, 4})
+                            .set_rng(0, &u4)
+                            .set_param(param)
+                            .set_epsilon(1e-3)
+                            .execs({{n, c / 64, h, w, 64}, {}});
+                }
+            }
+        }
+    }
+}
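
Note: the cases that pair different src/dst quantization params need set_epsilon(1e-3) because requantization through the checker's float reference can differ from the kernel by a rounding step. A small check mirroring one of the asymmetric cases:

    #include <cassert>
    #include <cmath>

    int main() {
        // Mirrors Quantized4Asymm{1.2f, 8} -> Quantized4Asymm{1.2f, 4}:
        // stored 11 means real 1.2 * (11 - 8) = 3.6, which re-encodes as
        // round(3.6 / 1.2) + 4 = 7 under the destination parameters.
        float real = 1.2f * (11 - 8);
        int dst_q = int(std::round(real / 1.2f)) + 4;
        assert(dst_q == 7);
        return 0;
    }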
 #if MEGDNN_WITH_BENCHMARK
 TEST_F(CUDA, BENCHMARK_RELAYOUT_FORMAT) {
     using Param = RelayoutFormat::Param;