
feat(dnn/cuda): support transforming layout between nchw and nchw64 when channel not aligned to 64

GitOrigin-RevId: e9ecbcf2e2
release-1.5
Megvii Engine Team, 4 years ago
commit 4fe68ac9ed
6 changed files with 375 additions and 112 deletions
  1. dnn/src/common/relayout_format.cpp (+14, -8)
  2. dnn/src/cuda/relayout_format/helper.cuh (+149, -0)
  3. dnn/src/cuda/relayout_format/relayout_format.cu (+64, -50)
  4. dnn/src/cuda/utils.cuh (+7, -7)
  5. dnn/src/naive/relayout_format/opr_impl.cpp (+119, -35)
  6. dnn/test/cuda/relayout_format.cpp (+22, -12)
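In short, NCHW to NCHW64 now rounds the channel count up to a multiple of 64 and fills the padded tail with the quantized zero point, while NCHW64 to NCHW uses the new oc parameter to recover the original, unaligned channel count. A minimal host-side sketch of just the shape rule (illustrative only; div_ceil here is a local helper, not the MegDNN one):

    #include <cstddef>
    #include <cstdio>

    static size_t div_ceil(size_t a, size_t b) { return (a + b - 1) / b; }

    int main() {
        size_t n = 1, c = 15, h = 7, w = 7;                 // channel not aligned to 64
        size_t nchw64[5] = {n, div_ceil(c, 64), h, w, 64};  // {1, 1, 7, 7, 64}
        size_t oc = 15;                                     // RelayoutFormat param().oc
        size_t back_c = oc != 0 ? oc : nchw64[1] * 64;      // 15, not 64
        printf("NCHW64 blocks=%zu, restored C=%zu\n", nchw64[1], back_c);
        return 0;
    }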

dnn/src/common/relayout_format.cpp (+14, -8)

@@ -252,10 +252,10 @@ void RelayoutFormat::deduce_layout_fwd(const TensorLayout& src,
             megdnn_assert(dst[1] % param().group == 0);
             break;
         case Param::Mode::NCHW_NCHW64:
-            megdnn_assert(src.ndim == 4 && (src[1] % 64) == 0);
+            megdnn_assert(src.ndim == 4);
             dst.ndim = 5;
             dst[0] = src[0];
-            dst[1] = src[1] / 64;
+            dst[1] = div_ceil(src[1], 64_z);
             dst[2] = src[2];
             dst[3] = src[3];
             dst[4] = 64;
@@ -264,7 +264,7 @@ void RelayoutFormat::deduce_layout_fwd(const TensorLayout& src,
             megdnn_assert(src.ndim == 5);
             dst.ndim = 4;
             dst[0] = src[0];
-            dst[1] = src[1] * 64;
+            dst[1] = param().oc == 0 ? src[1] * 64 : param().oc;
             dst[2] = src[2];
             dst[3] = src[3];
             break;
@@ -483,12 +483,11 @@ void RelayoutFormat::deduce_exec_layout(const TensorLayout& src,
         case Param::Mode::NCHW4_NCHW:
             // nchw to nchw4
             {
-                megdnn_assert(src.format == dst.format);
                 exec_workspace =
                         TensorLayout({src[0], src[1] * 4, src[2], src[3]},
-                                     src.dtype, src.format)
-                                .reshape({src[0], src[1], 4, src[2], src[3]})
-                                .dimshuffle({0, 1, 3, 4, 2});
-                exec_src = src;
+                                     dst.dtype, dst.format);
+                exec_src = src.dimshuffle({0, 1, 4, 2, 3});
                 exec_dst = dst;
             }
             break;
@@ -658,13 +657,20 @@ void RelayoutFormat::deduce_exec_layout(const TensorLayout& src,
         case Param::Mode::NCHW_NCHW64:
             // src is {N, C, H, W}
             // dst is {N, C/64, H, W, 64}
-            exec_src = src.reshape({src[0], src[1] / 64, 64, src[2], src[3]})
+            exec_workspace = TensorLayout(
+                    {src[0], round_up(src[1], 64_z), src[2], src[3]},
+                    src.dtype);
+            exec_src = exec_workspace
+                               .reshape({src[0], div_ceil(src[1], 64_z), 64,
+                                         src[2], src[3]})
                                .dimshuffle({0, 1, 3, 4, 2});
             exec_dst = dst;
             break;
         case Param::Mode::NCHW64_NCHW:
             // src is {N, C/64, H, W, 64}
             // dst is {N, C, H, W}
+            exec_workspace = TensorLayout({src[0], src[1] * 64, src[2], src[3]},
+                                          dst.dtype);
             exec_src = src.dimshuffle({0, 1, 4, 2, 3});
             exec_dst = dst;
             break;
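To make the exec-layout change above concrete: padding C up to a multiple of 64 and then applying reshape({N, ceil(C/64), 64, H, W}) followed by dimshuffle({0, 1, 3, 4, 2}) is exactly the index mapping below (a hypothetical helper written only to illustrate the layout, not MegDNN code):

    #include <cstddef>

    // Element (n, c, h, w) of the zero-point-padded NCHW tensor lands at
    // (n, c / 64, h, w, c % 64) of the NCHW64 tensor.
    constexpr size_t nchw64_offset(size_t n, size_t c, size_t h, size_t w,
                                   size_t C_blocks, size_t H, size_t W) {
        return (((n * C_blocks + c / 64) * H + h) * W + w) * 64 + c % 64;
    }
    // With C = 15 (padded to 64), channel 14 of pixel (0, 0) sits at offset 14.
    static_assert(nchw64_offset(0, 14, 0, 0, 1, 7, 7) == 14, "last real channel");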


dnn/src/cuda/relayout_format/helper.cuh (+149, -0)

@@ -0,0 +1,149 @@
/**
* \file dnn/src/cuda/relayout_format/helper.cuh
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

namespace megdnn {
namespace cuda {
namespace relayout_format {

#define devfunc __forceinline__ __device__
template <int size_nbits>
devfunc int make_zero(int zero_point);

template <>
devfunc int make_zero<4>(int zero_point) {
return transform_int8_to_uint4x8(zero_point, zero_point, zero_point,
zero_point, zero_point, zero_point,
zero_point, zero_point);
}

template <typename AccessType, int LoadBytes>
struct global_load_with_zero_point;

/////////////////////////////////////////////////////////////////////////////////////////////////
//
// Specializations
//
/////////////////////////////////////////////////////////////////////////////////////////////////

/////////////////////////////////////////////////////////////////////////////////////////////////

// The redundant mov PTX instructions force the compiler to initialize the
// destination registers with the zero-point pattern before ld.global
template <typename AccessType>
struct global_load_with_zero_point<AccessType, 32> {
devfunc global_load_with_zero_point(AccessType& D, void const* ptr,
bool pred_guard, int zero_point) {
uint4* data = reinterpret_cast<uint4*>(&D);

asm volatile(
"{\n"
" .reg .pred p;\n"
" setp.ne.b32 p, %9, 0;\n"
" mov.b32 %0, %10;\n"
" mov.b32 %1, %10;\n"
" mov.b32 %2, %10;\n"
" mov.b32 %3, %10;\n"
" mov.b32 %4, %10;\n"
" mov.b32 %5, %10;\n"
" mov.b32 %6, %10;\n"
" mov.b32 %7, %10;\n"
" @p ld.global.v4.u32 {%0, %1, %2, %3}, [%8];\n"
" @p ld.global.v4.u32 {%4, %5, %6, %7}, [%11];\n"
"}\n"
: "=r"(data[0].x), "=r"(data[0].y), "=r"(data[0].z),
"=r"(data[0].w), "=r"(data[1].x), "=r"(data[1].y),
"=r"(data[1].z), "=r"(data[1].w)
: "l"(ptr), "r"((int)pred_guard),
"r"(reinterpret_cast<unsigned&>(zero_point)),
"l"(((uint8_t*)ptr) + 16));
}
};

template <typename AccessType>
struct global_load_with_zero_point<AccessType, 16> {
devfunc global_load_with_zero_point(AccessType& D, void const* ptr,
bool pred_guard, int zero_point) {
uint4& data = reinterpret_cast<uint4&>(D);

asm volatile(
"{\n"
" .reg .pred p;\n"
" setp.ne.b32 p, %5, 0;\n"
" mov.b32 %0, %6;\n"
" mov.b32 %1, %6;\n"
" mov.b32 %2, %6;\n"
" mov.b32 %3, %6;\n"
" @p ld.global.v4.u32 {%0, %1, %2, %3}, [%4];\n"
"}\n"
: "=r"(data.x), "=r"(data.y), "=r"(data.z), "=r"(data.w)
: "l"(ptr), "r"((int)pred_guard),
"r"(reinterpret_cast<unsigned&>(zero_point)));
}
};

template <typename AccessType>
struct global_load_with_zero_point<AccessType, 8> {
devfunc global_load_with_zero_point(AccessType& D, void const* ptr,
bool pred_guard, int zero_point) {
uint2& data = reinterpret_cast<uint2&>(D);

asm volatile(
"{\n"
" .reg .pred p;\n"
" setp.ne.b32 p, %3, 0;\n"
" mov.b32 %0, %4;\n"
" mov.b32 %1, %4;\n"
" @p ld.global.v2.u32 {%0, %1}, [%2];\n"
"}\n"
: "=r"(data.x), "=r"(data.y)
: "l"(ptr), "r"((int)pred_guard),
"r"(reinterpret_cast<unsigned&>(zero_point)));
}
};

template <typename AccessType>
struct global_load_with_zero_point<AccessType, 4> {
devfunc global_load_with_zero_point(AccessType& D, void const* ptr,
bool pred_guard, int zero_point) {
unsigned& data = reinterpret_cast<unsigned&>(D);

asm volatile(
"{\n"
" .reg .pred p;\n"
" setp.ne.b32 p, %2, 0;\n"
" mov.b32 %0, %3;\n"
" @p ld.global.u32 %0, [%1];\n"
"}\n"
: "=r"(data)
: "l"(ptr), "r"((int)pred_guard),
"r"(reinterpret_cast<unsigned&>(zero_point)));
}
};

template <typename AccessType>
struct global_load_with_zero_point<AccessType, 1> {
devfunc global_load_with_zero_point(AccessType& D, void const* ptr,
bool pred_guard, int zero_point) {
if (pred_guard)
D = *(reinterpret_cast<AccessType const*>(ptr));
else {
unsigned uv = reinterpret_cast<unsigned&>(zero_point);
uint8_t& data = reinterpret_cast<uint8_t&>(D);
data = uv & 0xff;
}
}
};

#undef devfunc
} // namespace relayout_format
} // namespace cuda
} // namespace megdnn
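For readers who do not want to decode the PTX: each specialization above behaves like a guarded load that falls back to the replicated zero point instead of leaving the fragment uninitialized. Note that make_zero<4> appears to rely on transform_int8_to_uint4x8 from src/cuda/utils.cuh, which the including .cu file pulls in first. A scalar C++ sketch of the 4-byte case (an illustration, not the device code):

    #include <cstdint>
    #include <cstring>

    // Equivalent of the LoadBytes == 4 specialization, one word at a time: a
    // guarded load that substitutes the packed zero-point pattern when the
    // predicate is false.
    inline void load_or_zero_point(uint32_t& dst, const void* ptr, bool pred_guard,
                                   uint32_t zero_point_pattern) {
        if (pred_guard) {
            std::memcpy(&dst, ptr, sizeof(uint32_t));  // the ld.global.u32 path
        } else {
            dst = zero_point_pattern;  // e.g. eight uint4 zero points from make_zero<4>
        }
    }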

dnn/src/cuda/relayout_format/relayout_format.cu (+64, -50)

@@ -18,6 +18,7 @@
 #pragma GCC diagnostic pop
 #include "src/cuda/query_blocksize.cuh"
 #include "src/cuda/relayout_format/relayout_format.cuh"
+#include "src/cuda/relayout_format/helper.cuh"
 using namespace megdnn;
 using namespace cuda;

@@ -728,17 +729,18 @@ struct Translayout<64, 2, SrcType, dtype::Quantized4Asymm,
 #undef pack

 template <typename DstType>
-inline __device__ DstType make_zero_pad(const char zero_point) {
+inline __device__ DstType make_zero_pad(const uint8_t zero_point) {
     return zero_point;
 }

 template <>
-inline __device__ char4 make_zero_pad<char4>(const char zero_point) {
-    return {zero_point, zero_point, zero_point, zero_point};
+inline __device__ char4 make_zero_pad<char4>(const uint8_t zero_point) {
+    char izp = reinterpret_cast<const char&>(zero_point);
+    return {izp, izp, izp, izp};
 }

 template <>
-inline __device__ int4 make_zero_pad<int4>(const char zero_point) {
+inline __device__ int4 make_zero_pad<int4>(const uint8_t zero_point) {
     return {zero_point, zero_point, zero_point, zero_point};
 }

@@ -767,7 +769,7 @@ inline __device__ void write_helper<array_wrapper<uint8_t, 32>>(
             : "l"(ptr_), "r"(data[0].x), "r"(data[0].y), "r"(data[0].z),
               "r"(data[0].w), "l"(((uint8_t*)ptr_) + 16), "r"(data[1].x),
               "r"(data[1].y), "r"(data[1].z), "r"(data[1].w));
-};
+}

 template <bool with_pad, int pack_w, int pack_c, bool same_scale, bool all_pad,
           typename SrcType, typename DstType, typename DnnSrcType,
@@ -825,7 +827,7 @@ struct RelayoutKern {
             const SrcType* src, DstType* dst, const int ic_stride,
             const int remain_ic,
             CudaPostProcess<DnnSrcType, DnnDstType, same_scale>& post_process,
-            const char zero_point) {
+            const uint8_t zero_point) {
         InnerDtype read_channel[pack_c];
         if (all_pad) {
             const InnerDtype zero_pad = make_zero_pad<InnerDtype>(zero_point);
@@ -855,7 +857,7 @@ __global__ void kern_nchw_nchwx(
         const SrcType* src, DstType* dst, int in_n, int ic, int ihw,
         int n_stride_src, int ic_stride, int n_stride_dst, int oc_stride,
         CudaPostProcess<DnnSrcType, DnnDstType, same_scale> post_process,
-        const char zero_point, const int group, const int ocpg) {
+        const uint8_t zero_point, const int group, const int ocpg) {
     static constexpr int size_src_type = sizeof(SrcType);
     static constexpr int size_dst_type = sizeof(DstType);
 #ifndef MEGDNN_COMMA
@@ -1072,6 +1074,7 @@ public:
     MEGDNN_DEVICE __forceinline__ void initialize(int c_idx, int hw_idx) {
         pointer += (c_idx / pack_size) * chan_stride_in_elements +
                    hw_idx * pack_size * size_nbits / (8 * sizeof(Type));
+        channel -= c_idx;
     }

     MEGDNN_DEVICE __forceinline__ void add_pointer_offset(
@@ -1079,7 +1082,7 @@ public:
         pointer += offset_in_type;
     }

-    MEGDNN_DEVICE __forceinline__ void load(Fragment& frag) {
+    MEGDNN_DEVICE __forceinline__ void load(Fragment& frag, int zero_point) {
         AccessType* frag_ptr = reinterpret_cast<AccessType*>(&frag);
         Type* pointer_ = pointer;
 #pragma unroll
@@ -1090,11 +1093,12 @@ public:
                                 (lane_size_in_type / pack_size_in_type) +
                         j;
                 bool guard = i < channel;
-                cutlass::arch::global_load<AccessType, pack_size_in_byte>(
+                relayout_format::global_load_with_zero_point<AccessType,
+                                                             pack_size_in_byte>(
                         frag_ptr[frag_idx],
                         reinterpret_cast<void*>(pointer_ +
                                                 j * pack_size_in_type),
-                        guard);
+                        guard, zero_point);
             }
             pointer_ += chan_stride_in_elements;
         }
@@ -1173,6 +1177,7 @@ public:

     MEGDNN_DEVICE __forceinline__ void initialize(int c_idx, int hw_idx) {
         pointer += (c_idx / pack_size) * chan_stride_in_elements;
+        channel -= c_idx;
 #pragma unroll
         for (int i = 0; i < mask_size; ++i) {
             mask[i] = 0;
@@ -1201,7 +1206,7 @@ public:
         pointer += offset_in_type;
     }

-    MEGDNN_DEVICE __forceinline__ void load(Fragment& frag) {
+    MEGDNN_DEVICE __forceinline__ void load(Fragment& frag, int zero_point) {
         AccessType* frag_ptr = reinterpret_cast<AccessType*>(&frag);
         Type* pointer_ = pointer;
 #pragma unroll
@@ -1214,9 +1219,11 @@ public:
                 int mask_index = (frag_idx >> 5);
                 int mask_shift = (frag_idx & 0x1f);
                 bool guard = (mask[mask_index] & (1 << mask_shift));
-                cutlass::arch::global_load<AccessType, pack_size_in_byte>(
+                relayout_format::global_load_with_zero_point<AccessType,
+                                                             pack_size_in_byte>(
                         frag_ptr[frag_idx],
-                        reinterpret_cast<void*>(pointer_ + stride[j]), guard);
+                        reinterpret_cast<void*>(pointer_ + stride[j]), guard,
+                        zero_point);
             }
             pointer_ += chan_stride_in_elements;
         }
@@ -1306,11 +1313,13 @@ struct RelayoutProblem {
         int batch_size;
         int channels;
         int hw;
+        int zero_point;
         MEGDNN_HOST MEGDNN_DEVICE Param(SrcIterator src_iterator_,
                                         DstIterator dst_iterator_,
                                         CudaPostProcess post_process_,
                                         int n_stride_src_, int n_stride_dst_,
-                                        int batch_size_, int channels_, int hw_)
+                                        int batch_size_, int channels_, int hw_,
+                                        int zero_point_)
                 : src_iterator{src_iterator_},
                   dst_iterator{dst_iterator_},
                   post_process{post_process_},
@@ -1318,7 +1327,8 @@ struct RelayoutProblem {
                   n_stride_dst{n_stride_dst_},
                   batch_size{batch_size_},
                   channels{channels_},
-                  hw{hw_} {}
+                  hw{hw_},
+                  zero_point{zero_point_} {}
     };
 };

@@ -1345,7 +1355,9 @@ __global__ void relayout_kern(typename RelayoutProblem_::Param param) {
         param.dst_iterator.initialize(c_idx, hw_idx);
         typename SrcIterator::Fragment src_frag;
         typename DstIterator::Fragment dst_frag;
-        param.src_iterator.load(src_frag);
+        int zp = relayout_format::make_zero<SrcIterator::size_nbits>(
+                param.zero_point);
+        param.src_iterator.load(src_frag, zp);
         RelayoutProblem_::Transpose::trans(
                 reinterpret_cast<typename SrcIterator::Fragment&>(dst_frag),
                 src_frag, param.post_process);
@@ -1382,7 +1394,8 @@ void relayout_format::relayout_format_cuda_nchw_nchwx(
                   stype.name(), dtype.name());
 #undef DEF
     // no padding
-    if (src.layout.stride[2] == static_cast<ptrdiff_t>(src.layout[3])) {
+    if (stype.enumv().ev != DTypeEnum::Ev::QuantizedS4 &&
+        stype.enumv().ev != DTypeEnum::Ev::Quantized4Asymm) {
         const int in_n = src.layout[0];
         const int out_n = dst.layout[0];
         const int ic = src.layout[1];
@@ -1428,18 +1441,10 @@ void relayout_format::relayout_format_cuda_nchw_nchwx(
     DISPATCH_RAW(false, 4, 4, _src_type, _dst_type, char, char, 8); \
     DISPATCH_RAW(true, 1, 4, _src_type, _dst_type, char, char, 8); \
     DISPATCH_RAW(false, 1, 4, _src_type, _dst_type, char, char, 8);
-#define DISPATCH_4BITS(_src_type, _dst_type) \
-    DISPATCH_RAW(true, 8, 64, _src_type, _dst_type, char, char, 4); \
-    DISPATCH_RAW(false, 8, 64, _src_type, _dst_type, char, char, 4); \
-    DISPATCH_RAW(true, 2, 64, _src_type, _dst_type, char, char, 4); \
-    DISPATCH_RAW(false, 2, 64, _src_type, _dst_type, char, char, 4);
         DISPATCH_INT(QuantizedS32, QuantizedS32);
         DISPATCH_BYTE(Uint8, QuantizedS8);
         DISPATCH_BYTE(Quantized8Asymm, QuantizedS8);
         DISPATCH_BYTE(QuantizedS8, QuantizedS8);
-        DISPATCH_4BITS(QuantizedS4, QuantizedS4);
-        DISPATCH_4BITS(Quantized4Asymm, Quantized4Asymm);
-#undef DISPATCH_4BITS
 #undef DISPATCH_BYTE
 #undef DISPATCH_INT
 #undef DISPATCH_RAW
@@ -1450,7 +1455,8 @@ void relayout_format::relayout_format_cuda_nchw_nchwx(
     } else {
         megdnn_assert(src_layout.dtype.is_low_bit());
         int n = src.layout[0];
-        int c = src.layout[1];
+        int ic = src.layout[1];
+        int oc = dst.layout[1] * 64;
         int h = src.layout[2];
         // align to byte
         int w = src.layout[3];
@@ -1460,12 +1466,13 @@ void relayout_format::relayout_format_cuda_nchw_nchwx(
         int ic_stride = src_layout.stride[1];
         int n_stride_dst = dst_layout.stride[0];
        int oc_stride = dst_layout.stride[1];
-        int problem_size = n * (c / pack_oc) * hw;
+        int problem_size = n * (oc / pack_oc) * hw;
         bool same_scale = src_scale == dst_scale;
-#define DISPATCH_RAW(_same_scale, _pack_w, _pack_oc, _src_type, _dst_type, \
-                     _src_c_type, _dst_c_type, _size_nbits) \
-    if (same_scale == _same_scale && hw % _pack_w == 0 && \
-        stype.enumv().ev == DTypeEnum::Ev::_src_type && \
+        bool padding = w % 2 != 0;
+#define DISPATCH_RAW(_padding, _same_scale, _pack_w, _pack_oc, _src_type, \
+                     _dst_type, _src_c_type, _dst_c_type, _size_nbits) \
+    if (padding == _padding && same_scale == _same_scale && \
+        hw % _pack_w == 0 && stype.enumv().ev == DTypeEnum::Ev::_src_type && \
         dtype.enumv().ev == DTypeEnum::Ev::_dst_type) { \
         using InnerDtype_ = typename DTypeRWHelper< \
                 typename DTypeTrait<dtype::_src_type>::ctype, \
@@ -1473,8 +1480,10 @@ void relayout_format::relayout_format_cuda_nchw_nchwx(
         using SrcIterator_ = \
                 TensorIteratorOverChannel<InnerDtype_, 1, _pack_oc, _pack_w, \
                                           _size_nbits>; \
-        using DstIterator_ = MaskedTensorIteratorOverChannel< \
-                _dst_c_type, _pack_oc, _pack_oc, _pack_w, _size_nbits>; \
+        using DstIterator_ = \
+                typename TensorIteratorPolicy<_padding, _dst_c_type, _pack_oc, \
+                                              _pack_oc, _pack_w, \
+                                              _size_nbits>::TensorIterator; \
         using CudaPostProcess_ = \
                 CudaPostProcess<dtype::_src_type, dtype::_dst_type, \
                                 _same_scale>; \
@@ -1489,17 +1498,18 @@ void relayout_format::relayout_format_cuda_nchw_nchwx(
     n_stride_dst = n_stride_dst * _size_nbits / (8 * sizeof(_dst_c_type)); \
     oc_stride = oc_stride * _size_nbits / (8 * sizeof(_dst_c_type)); \
     typename RelayoutProblem_::Param param{ \
-            SrcIterator_{(InnerDtype_*)src.raw_ptr, ic_stride, c, w, \
+            SrcIterator_{(InnerDtype_*)src.raw_ptr, ic_stride, ic, w, \
                          w_pad}, \
-            DstIterator_{(_dst_c_type*)dst.raw_ptr, oc_stride, c, w, \
+            DstIterator_{(_dst_c_type*)dst.raw_ptr, oc_stride, oc, w, \
                          w_pad}, \
             CudaPostProcess_{src_scale, src_zero_point, dst_scale, \
                              dst_zero_point}, \
             n_stride_src, \
             n_stride_dst, \
             n, \
-            c, \
-            hw}; \
+            oc, \
+            hw, \
+            src_zero_point}; \
     auto kernel = relayout_kern<RelayoutProblem_>; \
     int nr_threads = query_blocksize_for_kernel(kernel); \
     nr_threads = std::min(nr_threads, DIVUP(problem_size, _pack_w)); \
@@ -1507,11 +1517,15 @@ void relayout_format::relayout_format_cuda_nchw_nchwx(
     const dim3 thread_dim(nr_threads); \
     return kernel<<<block_dim, thread_dim, 0, stream>>>(param); \
     }
-#define DISPATCH_4BITS(_src_type, _dst_type) \
-    DISPATCH_RAW(true, 8, 64, _src_type, _dst_type, char, char, 4); \
-    DISPATCH_RAW(false, 8, 64, _src_type, _dst_type, char, char, 4); \
-    DISPATCH_RAW(true, 2, 64, _src_type, _dst_type, char, char, 4); \
-    DISPATCH_RAW(false, 2, 64, _src_type, _dst_type, char, char, 4);
+#define DISPATCH_4BITS(_src_type, _dst_type) \
+    DISPATCH_RAW(true, true, 8, 64, _src_type, _dst_type, char, char, 4); \
+    DISPATCH_RAW(true, false, 8, 64, _src_type, _dst_type, char, char, 4); \
+    DISPATCH_RAW(true, true, 2, 64, _src_type, _dst_type, char, char, 4); \
+    DISPATCH_RAW(true, false, 2, 64, _src_type, _dst_type, char, char, 4); \
+    DISPATCH_RAW(false, true, 8, 64, _src_type, _dst_type, char, char, 4); \
+    DISPATCH_RAW(false, false, 8, 64, _src_type, _dst_type, char, char, 4); \
+    DISPATCH_RAW(false, true, 2, 64, _src_type, _dst_type, char, char, 4); \
+    DISPATCH_RAW(false, false, 2, 64, _src_type, _dst_type, char, char, 4);
         DISPATCH_4BITS(QuantizedS4, QuantizedS4);
         DISPATCH_4BITS(Quantized4Asymm, Quantized4Asymm);
 #undef DISPATCH_4BITS
@@ -1521,7 +1535,6 @@ void relayout_format::relayout_format_cuda_nchw_nchwx(
                 "Unsupported data type(src:%s, dst:%s) or image size(%dx%d).",
                 stype.name(), dtype.name(), h, w);
     }
-    after_kernel_launch();
 }

 bool relayout_format::relayout_format_cuda_usable(
@@ -1568,7 +1581,7 @@ void relayout_format::relayout_format_cuda_nchwx_nchw(
     megdnn_assert(pack_ic == 64, "Unsupport pack size(pack_ic:%d)", pack_ic);
 #undef DEF
     int n = src.layout[0];
-    int c = src.layout[1] * pack_ic;
+    int ic = src.layout[1] * pack_ic;
     int h = src.layout[2];
     // align to byte
     int w = src.layout[3];
@@ -1578,7 +1591,8 @@ void relayout_format::relayout_format_cuda_nchwx_nchw(
     int ic_stride = src_layout.stride[1];
     int n_stride_dst = dst_layout.stride[0];
     int oc_stride = dst_layout.stride[1];
-    int problem_size = n * (c / pack_ic) * hw;
+    int problem_size = n * (ic / pack_ic) * hw;
+    int oc = dst.layout[1];

     bool same_scale = src_scale == dst_scale;
     bool padding = w % 2 != 0;
@@ -1611,17 +1625,18 @@ void relayout_format::relayout_format_cuda_nchwx_nchw(
     n_stride_dst = n_stride_dst * _size_nbits / (8 * sizeof(InnerDtype_)); \
     oc_stride = oc_stride * _size_nbits / (8 * sizeof(InnerDtype_)); \
     typename RelayoutProblem_::Param param{ \
-            SrcIterator_{(_src_c_type*)src.raw_ptr, ic_stride, c, w, \
+            SrcIterator_{(_src_c_type*)src.raw_ptr, ic_stride, ic, w, \
                          w_pad}, \
-            DstIterator_{(InnerDtype_*)dst.raw_ptr, oc_stride, c, w, \
+            DstIterator_{(InnerDtype_*)dst.raw_ptr, oc_stride, oc, w, \
                          w_pad}, \
             CudaPostProcess_{src_scale, src_zero_point, dst_scale, \
                              dst_zero_point}, \
             n_stride_src, \
             n_stride_dst, \
             n, \
-            c, \
-            hw}; \
+            ic, \
+            hw, \
+            src_zero_point}; \
     auto kernel = relayout_kern<RelayoutProblem_>; \
     int nr_threads = query_blocksize_for_kernel(kernel); \
     nr_threads = std::min(nr_threads, DIVUP(problem_size, _pack_w)); \
@@ -1645,7 +1660,6 @@ void relayout_format::relayout_format_cuda_nchwx_nchw(
     megdnn_assert(false,
                   "Unsupported data type(src:%s, dst:%s) or image size(%dx%d).",
                   stype.name(), dtype.name(), h, w);
-    after_kernel_launch();
 }

 void relayout_format::relayout_format_cuda_nchw4_nchw(
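A note on the new _padding dispatch flag introduced above: for 4-bit tensors two elements share one byte, so a row is byte-aligned only when W is even. The snippet below merely restates that check (it is an illustration, not code from the diff):

    #include <cstddef>

    // Two 4-bit elements per byte: an odd row width leaves a half-used byte at
    // the end of every row, which is why the kernels select the masked
    // ("padding") iterator when w % 2 != 0.
    constexpr bool needs_padding(size_t w) { return w % 2 != 0; }
    static_assert(needs_padding(55) && !needs_padding(56),
                  "odd widths take the padded path, even widths the fast path");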


dnn/src/cuda/utils.cuh (+7, -7)

@@ -21,6 +21,7 @@
 #include "cuda.h"
 #include "src/cuda/cudnn_with_check.h"
 #include "cutlass/cutlass.h"
+#include "cutlass/platform/platform.h"

 #define cuda_check(_x) \
     do { \
@@ -448,13 +449,12 @@ MEGDNN_DEVICE __forceinline__ static int transform_int8_to_uint4x8(
 template <bool signedness, typename T>
 MEGDNN_DEVICE __forceinline__ static int unpack_integer_4bits(T storage,
                                                               int bits) {
-    uint8_t result = (uint8_t)((storage >> bits) & 0xf);
-    if (signedness) {
-        static constexpr uint8_t mask = (uint8_t)((1 << 4) - 1);
-        return (result & uint8_t(1 << 3)) ? ((int)(result) | ~(int)(mask))
-                                          : (int)(result);
-    }
-    return int(result);
+    static constexpr int shift = 28;
+    using type = typename cutlass::platform::conditional<signedness, int,
+                                                         unsigned>::type;
+    unsigned intermediate = static_cast<unsigned>(storage);
+    type result = reinterpret_cast<type&>(intermediate);
+    return (result << (shift - bits)) >> shift;
 }

 MEGDNN_DEVICE __forceinline__ static void transform_int4x8_to_int8(
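The rewritten unpack_integer_4bits above replaces the mask-and-branch with a shift pair: moving the nibble to the top of a 32-bit word and shifting back arithmetically sign-extends signed values in one step (shift = 28 = 32 - 4). A host-side sketch of the same trick (illustrative only; it relies on the usual arithmetic right-shift behaviour of signed integers):

    #include <cassert>
    #include <cstdint>
    #include <type_traits>

    template <bool signedness>
    int unpack4(uint32_t storage, int bits) {
        constexpr int shift = 28;
        using type = typename std::conditional<signedness, int, unsigned>::type;
        // Move the selected nibble into the top 4 bits, then shift it back down;
        // for the signed case the arithmetic shift replicates the sign bit.
        type v = static_cast<type>(storage << (shift - bits));
        return static_cast<int>(v >> shift);
    }

    int main() {
        assert(unpack4<true>(0x0000000Fu, 0) == -1);   // signed nibble 0xF -> -1
        assert(unpack4<false>(0x0000000Fu, 0) == 15);  // unsigned nibble 0xF -> 15
        assert(unpack4<true>(0x00000070u, 4) == 7);    // second nibble, value 7
        return 0;
    }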


dnn/src/naive/relayout_format/opr_impl.cpp (+119, -35)

@@ -42,6 +42,36 @@ void recursive_cp(const TensorND& dst, const TensorND& src, size_t idx = 0,
     }
 }

+template <size_t size_nbits>
+void lowbit_recursive_cp(const TensorND& dst, const TensorND& src,
+                         size_t idx = 0, size_t src_offset = 0,
+                         size_t dst_offset = 0) {
+    MEGDNN_STATIC_ASSERT(!(8_z % size_nbits),
+                         "size in bits of lowbit data type can only be 1, 2, 4 "
+                         "or 8");
+    if (idx < (src.layout.ndim - 1)) {
+        for (size_t i = 0; i < src.layout[idx]; ++i) {
+            lowbit_recursive_cp<size_nbits>(
+                    dst, src, idx + 1, src_offset + i * src.layout.stride[idx],
+                    dst_offset + i * dst.layout.stride[idx]);
+        }
+    } else {
+        megdnn_assert(src.layout.stride[idx] == 1);
+        megdnn_assert(dst.layout.stride[idx] == 1);
+        size_t dim_bytes = div_ceil(src.layout[idx], 8_z / size_nbits);
+        // offset in elements
+        uint8_t* dptr = reinterpret_cast<uint8_t*>(dst.raw_ptr) +
+                        (dst_offset * size_nbits / 8);
+        uint8_t* sptr = reinterpret_cast<uint8_t*>(src.raw_ptr) +
+                        (src_offset * size_nbits / 8);
+        for (size_t i = 0; i < dim_bytes; ++i) {
+            *dptr = *sptr;
+            dptr++;
+            sptr++;
+        }
+    }
+}
+
 void padding_to_workspace(_megdnn_tensor_out dst, _megdnn_tensor_in src) {
     switch (src.layout.dtype.enumv()) {
 #define cb(name, ctype) \
@@ -54,10 +84,17 @@ void padding_to_workspace(_megdnn_tensor_out dst, _megdnn_tensor_in src) {
         cb(Int32, dt_int32);
         cb(QuantizedS32, dt_int32);
         cb(QuantizedS8, dt_qint8);

+#undef cb
+#define cb(name, size_nbits) \
+    case (DTypeEnum::name): { \
+        lowbit_recursive_cp<size_nbits>(dst, src); \
+        break; \
+    }
+        cb(QuantizedS4, 4);
+        cb(Quantized4Asymm, 4);
+#undef cb
         default:
             megdnn_assert(0, "not support dtype %s", src.layout.dtype.name());
-#undef cb
     }
 }

@@ -66,24 +103,27 @@ void extract_from_workspace(_megdnn_tensor_out dst, _megdnn_tensor_in src,
     megdnn_assert(dst.layout.is_contiguous() && src.layout.is_contiguous(),
                   "dst %s, src %s", dst.layout.to_string().c_str(),
                   src.layout.to_string().c_str());
-    const size_t type_size = dst.layout.dtype.size();
     const size_t n = dst.layout[0];
-    const size_t n_stride_dst = dst.layout.stride[0];
-    const size_t n_stride_src = src.layout.stride[0];
+    const size_t n_stride_dst_in_bytes =
+            dst.layout.dtype.size(dst.layout.stride[0]);
+    const size_t n_stride_src_in_bytes =
+            src.layout.dtype.size(src.layout.stride[0]);
     const size_t ocpg = dst.layout[1] / group;
     const size_t icpg = src.layout[1] / group;
-    const size_t dst_hw = dst.layout[2] * dst.layout[3];
-    const size_t src_hw = src.layout[2] * src.layout[3];
-    megdnn_assert(dst_hw == src_hw);
+    const size_t dst_c_stride_in_bytes =
+            dst.layout.dtype.size(dst.layout.stride[1]);
+    const size_t src_c_stride_in_bytes =
+            src.layout.dtype.size(src.layout.stride[1]);
+    megdnn_assert(dst_c_stride_in_bytes == src_c_stride_in_bytes);
     for (size_t nid = 0; nid < n; ++nid) {
-        const size_t n_offset_dst = nid * n_stride_dst * type_size;
-        const size_t n_offset_src = nid * n_stride_src * type_size;
+        const size_t n_offset_dst = nid * n_stride_dst_in_bytes;
+        const size_t n_offset_src = nid * n_stride_src_in_bytes;
         for (size_t gid = 0; gid < group; ++gid) {
             memcpy((char*)dst.raw_ptr + n_offset_dst +
-                           gid * ocpg * dst_hw * type_size,
+                           gid * ocpg * dst_c_stride_in_bytes,
                    (char*)src.raw_ptr + n_offset_src +
-                           gid * icpg * src_hw * type_size,
-                   ocpg * dst_hw * type_size);
+                           gid * icpg * src_c_stride_in_bytes,
+                   ocpg * dst_c_stride_in_bytes);
         }
     }
 };
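The switch from element counts to byte strides in extract_from_workspace above is what makes the grouped copy safe for sub-byte dtypes. A small compile-time check of the arithmetic (a hypothetical helper, not part of MegDNN):

    #include <cstddef>

    // For QuantizedS4/Quantized4Asymm, dtype.size(n) is the byte size of n
    // elements, i.e. div_ceil(n * 4, 8); multiplying an element count by a
    // per-element byte size would be wrong because no whole number of bytes
    // exists per 4-bit element.
    constexpr size_t bytes_of_4bit(size_t nr_elems) { return (nr_elems * 4 + 7) / 8; }
    static_assert(bytes_of_4bit(15) == 8 && bytes_of_4bit(16) == 8,
                  "two 4-bit values per byte, rounded up");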
@@ -415,6 +455,30 @@ size_t RelayoutFormatImpl::get_workspace_in_bytes(const TensorLayout& src,
             return oc * ic * h * w * src.dtype.size();
         }

+        case Param::Mode::NCHW_NCHW64: {
+            if (src[1] % 64 != 0) {
+                size_t n = src[0];
+                size_t c = round_up(src[1], 64_z);
+                size_t h = src[2];
+                size_t w = src[3];
+                TensorLayout wsly({n, c, h, w}, src.dtype);
+                return wsly.span().dist_byte();
+            }
+            return 0_z;
+        }
+
+        case Param::Mode::NCHW64_NCHW: {
+            if (param().oc != 0) {
+                size_t n = src[0];
+                size_t c = src[1] * 64;
+                size_t h = src[2];
+                size_t w = src[3];
+                TensorLayout wsly({n, c, h, w}, dst.dtype);
+                return wsly.span().dist_byte();
+            }
+            return 0_z;
+        }
+
         default:
             return 0;
     }
@@ -437,6 +501,7 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
     // clean dst
     MEGDNN_DISPATCH_CPU_KERN(
             m_handle, memset(dst.raw_ptr, 0, dst.layout.span().dist_byte()));
+    // pre
     if (param().mode == Param::Mode::NCHW_NHWCD4I) {
         size_t N = src.layout[0];
         size_t IC = src.layout[1];
@@ -551,6 +616,27 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
             cb2(2, 4, NCHW_NCHW4, group_src_layout, workspace_layout);

         }
+    } else if (param().mode == Param::Mode::NCHW_NCHW64) {
+        MIDOUT_BEGIN(megdnn_naive_relayout_format,
+                     midout_iv(Param::Mode::NCHW_NCHW64)) {
+            size_t c = src.layout[1];
+            if (c % 64 != 0) {
+                uint8_t zp = 0;
+                if (src.layout.dtype.enumv() == DTypeEnum::Quantized4Asymm) {
+                    zp = src.layout.dtype.param<dtype::Quantized4Asymm>()
+                                 .zero_point;
+                    zp = (zp & 0xf) | (zp << 4);
+                }
+                MEGDNN_DISPATCH_CPU_KERN(
+                        m_handle, memset(workspace.raw_ptr, zp,
+                                         exec_workspace.span().dist_byte()));
+                TensorND ws_nd(workspace.raw_ptr, exec_workspace);
+                MEGDNN_DISPATCH_CPU_KERN(m_handle,
+                                         padding_to_workspace(ws_nd, src););
+                exec_src_nd.raw_ptr = workspace.raw_ptr;
+            }
+        }
+        MIDOUT_END();
     } else if (param().mode ==
                Param::Mode::NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT) {
         cb(1, 4, NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT);
@@ -574,24 +660,16 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
             cb(1, 2, 4, NCHW_NCHW4_WEIGHT);
         }
     } else if (param().mode == Param::Mode::NCHW4_NCHW) {
-        if (exec_workspace.total_nr_elems() == dst.layout.total_nr_elems()) {
-            m_handle->relayout_opr()->exec(
-                    exec_src_nd, {dst.raw_ptr, exec_workspace}, handle());
-            return;
-        } else {
-            m_handle->relayout_opr()->exec(
-                    exec_src_nd, {workspace.raw_ptr, exec_workspace}, handle());
-            TensorLayout workspace_layout{{src.layout[0], src.layout[1] * 4,
-                                           src.layout[2], src.layout[3]},
-                                          src.layout.dtype,
-                                          src.layout.format};
-            extract_from_workspace(exec_dst_nd,
-                                   {workspace.raw_ptr, workspace_layout},
-                                   param().group);
-            return;
+        if (exec_workspace.total_nr_elems() != dst.layout.total_nr_elems()) {
+            exec_dst_nd = {workspace.raw_ptr, exec_workspace};
+        }
+    } else if (param().mode == Param::Mode::NCHW64_NCHW) {
+        if (exec_workspace.total_nr_elems() != dst.layout.total_nr_elems()) {
+            exec_dst_nd = {workspace.raw_ptr, exec_workspace};
         }
     }

+    // do relayout
     if (src.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm &&
         dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8) {
         TensorND src0 = exec_src_nd, dst0 = exec_dst_nd;
@@ -600,7 +678,6 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
             do_copy_diff_qu8_q8(dst, src);
         };
         MEGDNN_DISPATCH_CPU_KERN_OPR(func(dst0, src0));
-        return;
     } else if (src.layout.dtype.enumv() == DTypeEnum::Uint8 &&
                dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8) {
         TensorND src0 = exec_src_nd, dst0 = exec_dst_nd;
@@ -609,7 +686,6 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
             do_copy_diff_u8_q8(dst, src);
         };
         MEGDNN_DISPATCH_CPU_KERN_OPR(func(dst0, src0));
-        return;
     } else if (src.layout.dtype.enumv() == DTypeEnum::QuantizedS8 &&
                dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8) {
         TensorND src0 = exec_src_nd, dst0 = exec_dst_nd;
@@ -618,7 +694,6 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
             do_copy_diff_q8_q8(dst, src);
         };
         MEGDNN_DISPATCH_CPU_KERN_OPR(func(dst0, src0));
-        return;
     } else if (src.layout.dtype.enumv() == DTypeEnum::QuantizedS32 &&
                dst.layout.dtype.enumv() == DTypeEnum::QuantizedS32) {
         TensorND src0 = exec_src_nd, dst0 = exec_dst_nd;
@@ -627,7 +702,6 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
             do_copy_diff_q32_q32(dst, src);
         };
         MEGDNN_DISPATCH_CPU_KERN_OPR(func(dst0, src0));
-        return;
     } else if (src.layout.dtype.enumv() == DTypeEnum::QuantizedS4 &&
                dst.layout.dtype.enumv() == DTypeEnum::QuantizedS4) {
         TensorND src0 = exec_src_nd, dst0 = exec_dst_nd;
@@ -636,7 +710,6 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
             do_copy_diff_q4_q4(dst, src);
         };
         MEGDNN_DISPATCH_CPU_KERN_OPR(func(dst0, src0));
-        return;
     } else if (src.layout.dtype.enumv() == DTypeEnum::Quantized4Asymm &&
                dst.layout.dtype.enumv() == DTypeEnum::Quantized4Asymm) {
         TensorND src0 = exec_src_nd, dst0 = exec_dst_nd;
@@ -645,9 +718,20 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
             do_copy_diff_qu4_qu4(dst, src);
         };
         MEGDNN_DISPATCH_CPU_KERN_OPR(func(dst0, src0));
-        return;
     } else {
         m_handle->relayout_opr()->exec(exec_src_nd, exec_dst_nd, handle());
+    }
+
+    // post
+    if (param().mode == Param::Mode::NCHW4_NCHW ||
+        param().mode == Param::Mode::NCHW64_NCHW) {
+        if (exec_workspace.total_nr_elems() != dst.layout.total_nr_elems()) {
+            megdnn_assert(exec_workspace.dtype == dst.layout.dtype);
+            TensorND ws_nd{workspace.raw_ptr, exec_workspace};
+            MEGDNN_DISPATCH_CPU_KERN(
+                    m_handle,
+                    extract_from_workspace(dst, ws_nd, param().group););
+        }
     }
 #undef cb
 }
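One detail worth calling out from the NCHW_NCHW64 branch above: the workspace is memset with the 4-bit zero point replicated into both nibbles of the fill byte, so padded channels dequantize to zero for Quantized4Asymm data. A host-side check of that packing (illustrative only):

    #include <cassert>
    #include <cstdint>

    int main() {
        uint8_t zero_point = 8;  // example Quantized4Asymm zero point
        uint8_t fill = (zero_point & 0xf) | (zero_point << 4);
        assert(fill == 0x88);    // both packed 4-bit elements decode to 8
        return 0;
    }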


dnn/test/cuda/relayout_format.cpp (+22, -12)

@@ -18,7 +18,6 @@

 using namespace megdnn;
 using namespace test;
-#define MEGDNN_WITH_BENCHMARK 1

 TEST_F(CUDA, RELAYOUT_FORMAT) {
     Checker<RelayoutFormat> checker(handle_cuda());
@@ -245,7 +244,7 @@ TEST_F(CUDA, RELAYOUT_FORMAT_NCHW_NCHW64) {
     param::RelayoutFormat param;
     param.mode = param::RelayoutFormat::Mode::NCHW_NCHW64;
     for (size_t n : {1, 3}) {
-        for (size_t c : {64, 128}) {
+        for (size_t c : {15, 64, 128}) {
             for (size_t h : {7, 14, 16, 28}) {
                 for (size_t w : {2, 3, 7, 8, 16, 31}) {
                     checker.set_dtype(0, dtype::QuantizedS4{2.f})
@@ -285,36 +284,41 @@ TEST_F(CUDA, RELAYOUT_FORMAT_NCHW64_NCHW) {
     param::RelayoutFormat param;
     param.mode = param::RelayoutFormat::Mode::NCHW64_NCHW;
     for (size_t n : {1, 3}) {
-        for (size_t c : {64, 128}) {
+        for (size_t c : {15, 64, 128}) {
             for (size_t h : {7, 14, 16, 28}) {
                 for (size_t w : {2, 3, 4, 7, 14, 16, 17}) {
+                    if (c % 64 != 0) {
+                        param.oc = c;
+                    } else {
+                        param.oc = 0;
+                    }
                     checker.set_dtype(0, dtype::QuantizedS4{2.f})
                             .set_dtype(1, dtype::QuantizedS4{2.f})
                             .set_rng(0, &s4)
                             .set_param(param)
                             .set_epsilon(1e-3)
-                            .execs({{n, c / 64, h, w, 64}, {}});
+                            .execs({{n, (c + 63) / 64, h, w, 64}, {}});

                     checker.set_dtype(0, dtype::Quantized4Asymm{1.2f, 4})
                             .set_dtype(1, dtype::Quantized4Asymm{1.2f, 8})
                             .set_rng(0, &u4)
                             .set_param(param)
                             .set_epsilon(1e-3)
-                            .execs({{n, c / 64, h, w, 64}, {}});
+                            .execs({{n, (c + 63) / 64, h, w, 64}, {}});

                     checker.set_dtype(0, dtype::QuantizedS4{1.19990307f})
                             .set_dtype(1, dtype::QuantizedS4{1.f})
                             .set_rng(0, &s4)
                             .set_param(param)
                             .set_epsilon(1e-3)
-                            .execs({{n, c / 64, h, w, 64}, {}});
+                            .execs({{n, (c + 63) / 64, h, w, 64}, {}});

                     checker.set_dtype(0, dtype::Quantized4Asymm{1.20211209f, 8})
                             .set_dtype(1, dtype::Quantized4Asymm{1.f, 4})
                             .set_rng(0, &u4)
                             .set_param(param)
                             .set_epsilon(1e-3)
-                            .execs({{n, c / 64, h, w, 64}, {}});
+                            .execs({{n, (c + 63) / 64, h, w, 64}, {}});
                 }
             }
         }
@@ -375,10 +379,14 @@ TEST_F(CUDA, BENCHMARK_RELAYOUT_FORMAT_QS4) {
     CUBenchmarker<RelayoutFormat> benchmarker(handle_cuda());
     benchmarker.set_param(param);
     benchmarker.set_dtype(0, dtype::QuantizedS4{1.19990307f})
-            .set_dtype(1, dtype::QuantizedS4{1.20210322f});
+            .set_dtype(1, dtype::QuantizedS4{1.19990307f});

     for (auto&& shape : shapes) {
-        double memaccess = double(shape.total_nr_elems()) * 1e-6;
+        double memaccess =
+                double(TensorLayout(shape, dtype::QuantizedS4{1.f})
+                               .span()
+                               .dist_byte()) *
+                2e-6;
         auto time_ms = benchmarker.execs({shape, {}});
         printf("execute %s, time %.4f ms, %.4f GB/s\n",
                shape.to_string().c_str(), time_ms, memaccess / time_ms);
@@ -387,8 +395,9 @@ TEST_F(CUDA, BENCHMARK_RELAYOUT_FORMAT_QS4) {

     {
         TensorShapeArray shapes = {
-                {1, 64, 56, 56}, {16, 64, 56, 56}, {64, 64, 56, 56},
-                {1, 64, 56, 55}, {16, 64, 56, 55}, {64, 64, 56, 55},
+                {1, 64, 56, 56},  {16, 64, 56, 56}, {64, 64, 56, 56},
+                {1, 64, 56, 55},  {16, 64, 56, 55}, {64, 64, 56, 55},
+                {1, 256, 384, 640},
         };
         Param param;
         param.mode = param::RelayoutFormat::Mode::NCHW_NCHW64;
@@ -399,7 +408,8 @@ TEST_F(CUDA, BENCHMARK_RELAYOUT_FORMAT_QS4) {
                 {64, 1, 56, 56, 64},
                 {1, 32, 7, 7, 64},
                 {16, 32, 7, 7, 64},
-                {64, 32, 7, 7, 64},
+                {64, 32, 7, 7, 64},
+                {1, 4, 384, 640, 64},
         };
         Param param;
         param.mode = param::RelayoutFormat::Mode::NCHW64_NCHW;
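For reference, the bandwidth figure printed by the benchmark above appears to be derived as sketched below: the new memaccess term counts the bytes of the 4-bit tensor for both the read and the write (hence the factor 2), and dividing MB by milliseconds gives GB/s. The shapes and timing here are made up; only the arithmetic mirrors the test:

    #include <cstdio>

    int main() {
        double tensor_bytes = 64.0 * 64 * 56 * 56 / 2;  // QuantizedS4: two elements per byte
        double memaccess_mb = tensor_bytes * 2e-6;      // read + write, in MB
        double time_ms = 0.5;                           // made-up measurement
        printf("%.2f GB/s\n", memaccess_mb / time_ms);  // MB / ms == GB / s
        return 0;
    }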

