@@ -10,10 +10,10 @@
 * implied.
 */

#include <stdint.h>
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "cutlass/fast_math.h"
#include "cutlass/arch/memory.h"
#pragma GCC diagnostic pop
#include "src/cuda/query_blocksize.cuh"
@@ -112,6 +112,8 @@ struct CudaPostProcess<dtype::QuantizedS32, dtype::QuantizedS32, true> {
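// The specializations below add 4-bit support. When src and dst share one
// quantization scale (same_scale == true) the packed value is forwarded
// unchanged; in the same_scale == false case the two CudaDTypeParamImpl
// converters are presumably used to dequantize and requantize each element
// (their operator() bodies fall outside this hunk).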

template <>
struct CudaPostProcess<dtype::QuantizedS4, dtype::QuantizedS4, false> {
    using SrcType = dtype::QuantizedS4;
    using DstType = dtype::QuantizedS4;
    CudaDTypeParamImpl<dt_qint4> m_dst_type_cvt;
    CudaDTypeParamImpl<dt_qint4> m_src_type_cvt;
    CudaPostProcess(float src_scale, uint8_t, float dst_scale, uint8_t) {
@@ -126,12 +128,16 @@ struct CudaPostProcess<dtype::QuantizedS4, dtype::QuantizedS4, false> {

template <>
struct CudaPostProcess<dtype::QuantizedS4, dtype::QuantizedS4, true> {
    using SrcType = dtype::QuantizedS4;
    using DstType = dtype::QuantizedS4;
    CudaPostProcess(float, uint8_t, float, uint8_t){};
    inline __device__ int8_t operator()(int8_t val) { return val; }
};

template <>
struct CudaPostProcess<dtype::Quantized4Asymm, dtype::Quantized4Asymm, false> {
    using SrcType = dtype::Quantized4Asymm;
    using DstType = dtype::Quantized4Asymm;
    CudaDTypeParamImpl<dt_quint4> m_dst_type_cvt;
    CudaDTypeParamImpl<dt_quint4> m_src_type_cvt;
    CudaPostProcess(float src_scale, uint8_t src_zero_point, float dst_scale,
@@ -149,6 +155,8 @@ struct CudaPostProcess<dtype::Quantized4Asymm, dtype::Quantized4Asymm, false> {

template <>
struct CudaPostProcess<dtype::Quantized4Asymm, dtype::Quantized4Asymm, true> {
    using SrcType = dtype::Quantized4Asymm;
    using DstType = dtype::Quantized4Asymm;
    uint8_t m_src_zero_point = 0;
    uint8_t m_dst_zero_point = 0;
    CudaPostProcess(float, uint8_t src_zero_point, float,
@@ -328,13 +336,20 @@ struct Translayout<2, 64, SrcType, dtype::QuantizedS4, dtype::QuantizedS4,
            unpack_int4x2(6)
            unpack_int4x2(7)
        // clang-format on

            int frag_idx = i / 8;
            dst_frag[0 * 8 + frag_idx] = pack_channel(0);
            dst_frag[1 * 8 + frag_idx] = pack_channel(1);
#undef unpack_int4x2
        }
    }
    using Fragment = array_wrapper<SrcType, 64>;
    static inline __device__ void trans(
            Fragment& dst, Fragment& src,
            CudaPostProcess<DnnSrcType, DnnDstType, same_scale>& post_process) {
        trans(reinterpret_cast<DstDtype(&)[2]>(dst),
              reinterpret_cast<InnerDtype(&)[64]>(src), post_process, 0);
    }
};

template <typename SrcType, bool same_scale>
@@ -375,6 +390,13 @@ struct Translayout<8, 64, SrcType, dtype::QuantizedS4, dtype::QuantizedS4,
            dst_frag[7 * 8 + frag_idx] = pack_channel(7);
        }
    }
    using Fragment = array_wrapper<unsigned, 64>;
    static inline __device__ void trans(
            Fragment& dst, Fragment& src,
            CudaPostProcess<DnnSrcType, DnnDstType, same_scale>& post_process) {
        trans(reinterpret_cast<DstDtype(&)[8]>(dst),
              reinterpret_cast<InnerDtype(&)[64]>(src), post_process, 0);
    }
};
#undef pack_channel
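// The two QuantizedS4 specializations above transpose a tile of `width`
// spatial positions x 64 channels between the NCHW and NCHW64 element orders:
// unpack_int4x2 splits each source byte into two 4-bit lanes, every value is
// routed through post_process, and pack_channel appears to re-pack eight
// processed 4-bit values into one 32-bit output word. The trailing trans()
// overload merely adapts the flat Fragment arrays to the array-reference form
// used by the worker overload.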

@@ -428,6 +450,13 @@ struct Translayout<2, 64, SrcType, dtype::Quantized4Asymm,
#undef unpack_int4x2
        }
    }
    using Fragment = array_wrapper<SrcType, 64>;
    static inline __device__ void trans(
            Fragment& dst, Fragment& src,
            CudaPostProcess<DnnSrcType, DnnDstType, same_scale>& post_process) {
        trans(reinterpret_cast<DstDtype(&)[2]>(dst),
              reinterpret_cast<InnerDtype(&)[64]>(src), post_process, 0);
    }
};

template <typename SrcType, bool same_scale>
@@ -468,6 +497,13 @@ struct Translayout<8, 64, SrcType, dtype::Quantized4Asymm,
            dst_frag[7 * 8 + frag_idx] = pack_channel(7);
        }
    }
    using Fragment = array_wrapper<unsigned, 64>;
    static inline __device__ void trans(
            Fragment& dst, Fragment& src,
            CudaPostProcess<DnnSrcType, DnnDstType, same_scale>& post_process) {
        trans(reinterpret_cast<DstDtype(&)[8]>(dst),
              reinterpret_cast<InnerDtype(&)[64]>(src), post_process, 0);
    }
};
#undef pack_channel
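// The Quantized4Asymm specializations above follow the same pattern; the
// difference is that the asymmetric post_process additionally has the
// source/destination zero points available (see the CudaPostProcess
// specializations for dtype::Quantized4Asymm earlier in this file).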

@@ -1028,11 +1064,21 @@ public:
            : pointer{nullptr}, chan_stride_in_elements{0}, channel{0} {}
    MEGDNN_DEVICE TensorIteratorOverChannel(Type* pointer_,
                                            int chan_stride_in_elements_,
                                            int channel_)
                                            int channel_, int, int)
            : pointer{pointer_},
              chan_stride_in_elements{chan_stride_in_elements_},
              channel{channel_} {}

    MEGDNN_DEVICE __forceinline__ void initialize(int c_idx, int hw_idx) {
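        // chan_stride_in_elements is measured in units of Type, so only the
        // spatial term needs the sub-byte correction: pack_size * size_nbits /
        // (8 * sizeof(Type)) converts one packed pixel into its width in Type
        // units (for 4-bit data two elements share one byte).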
        pointer += (c_idx / pack_size) * chan_stride_in_elements +
                   hw_idx * pack_size * size_nbits / (8 * sizeof(Type));
    }

    MEGDNN_DEVICE __forceinline__ void add_pointer_offset(
            size_t offset_in_type) {
        pointer += offset_in_type;
    }

    MEGDNN_DEVICE __forceinline__ void load(Fragment& frag) {
        AccessType* frag_ptr = reinterpret_cast<AccessType*>(&frag);
        Type* pointer_ = pointer;
@@ -1087,64 +1133,224 @@ private:
    int channel;
};

template <int pack_w, int pack_c, bool same_scale, typename SrcType,
          typename DstType, typename DnnSrcType, typename DnnDstType,
          int size_nbits = 8>
__global__ void kern_nchwx_nchw(
        const SrcType* src, DstType* dst, int ic, int ihw, int n_stride_src,
        int ic_stride, int n_stride_dst, int oc_stride,
        CudaPostProcess<DnnSrcType, DnnDstType, same_scale> post_process,
        const char zero_point) {
    using InnerDtype =
            typename DTypeRWHelper<typename DTypeTrait<DnnSrcType>::ctype,
                                   pack_w>::InnerDtype;
    using SrcIterator = TensorIteratorOverChannel<SrcType, pack_c, pack_c,
                                                  pack_w, size_nbits>;
    using DstIteraotr = TensorIteratorOverChannel<InnerDtype, 1, pack_c, pack_w,
                                                  size_nbits>;
    using Transpose = Translayout<pack_c, pack_w, SrcType, DnnSrcType,
                                  DnnDstType, same_scale>;
    static constexpr int size_src_type = sizeof(SrcType);
    static constexpr int size_dst_type = sizeof(DstType);
    MEGDNN_STATIC_ASSERT(std::is_same<SrcType MEGDNN_COMMA DstType>::value,
                         "Currently this kernel only support accessing tensor "
                         "src and dst in same data type.");
    n_stride_src /= size_src_type;
    ic_stride /= size_src_type;
    n_stride_dst /= size_dst_type;
    oc_stride /= size_dst_type;
#undef MEGDNN_COMMA
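// Iterator over one channel block of a packed tensor that guards every access
// with a per-access predicate, so the spatial positions that only exist in
// the byte-aligned (width-padded) view of a 4-bit tensor are never touched.
// A rough worked example for the QuantizedS4 NCHW64 instantiation used by the
// dispatch code below (Type = char, pack_size = chan_blk = 64, width = 8,
// size_nbits = 4):
//   elements_in_type  = 64 * 8 * 4 / 8 = 256 chars per fragment
//   pack_size_in_type = 64 * 4 / 8     = 32 chars (32 bytes) per access
//   accesses          = 256 / 32       = 8, so mask_size = 1 32-bit word.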
template <typename Type_, int pack_size_, int chan_blk_, int width_,
          int size_nbits_>
class MaskedTensorIteratorOverChannel {
public:
    using Type = Type_;
    static constexpr int pack_size = pack_size_;
    static constexpr int chan_blk = chan_blk_;
    static constexpr int width = width_;
    static constexpr int size_nbits = size_nbits_;
    static constexpr int elements_in_type =
            chan_blk * width * size_nbits / (8 * sizeof(Type));
    static constexpr int lane_size_in_type =
            (width * pack_size * size_nbits) / (8 * sizeof(Type));
    static constexpr int pack_size_in_type =
            (pack_size * size_nbits) >= (8 * sizeof(Type))
                    ? (pack_size * size_nbits / (8 * sizeof(Type)))
                    : (width * pack_size * size_nbits / (8 * sizeof(Type)));
    static constexpr int pack_size_in_byte = pack_size_in_type * sizeof(Type);
    static constexpr int accesses = elements_in_type / pack_size_in_type;
    static constexpr int mask_size = (accesses + 32 - 1) / 32;
    using AccessType = array_wrapper<Type, pack_size_in_type>;
    using Fragment = array_wrapper<Type, elements_in_type>;

    const int n_idx = blockIdx.y;
    const int ihw_block_idx = blockIdx.x * blockDim.x + threadIdx.x;
    const int ihw_offset = ihw_block_idx * pack_w;
    const int ihw_offset_in_type =
            ihw_offset * size_nbits / (8 * size_src_type);
    const int oc_stride_inner_dtype =
            oc_stride * size_dst_type / sizeof(InnerDtype);
    if (ihw_offset < ihw) {
        const int ic_block = (ic + pack_c - 1) / pack_c;
        const int src_offset_base =
                n_idx * n_stride_src + ihw_offset_in_type * pack_c;
        const int dst_offset_base = n_idx * n_stride_dst + ihw_offset_in_type;
        SrcIterator src_iterator{const_cast<SrcType*>(src + src_offset_base),
                                 ic_stride, ic};
        DstIteraotr dst_iterator{
                reinterpret_cast<InnerDtype*>(dst + dst_offset_base),
                oc_stride_inner_dtype, ic};

        for (int ic_blk_idx = 0; ic_blk_idx < ic_block; ++ic_blk_idx) {
            typename SrcIterator::Fragment src_frag;
            typename DstIteraotr::Fragment dst_frag;
            src_iterator.load(src_frag);
            Transpose::trans(
                    reinterpret_cast<typename SrcIterator::Fragment&>(dst_frag),
                    src_frag, post_process);
            dst_iterator.store(dst_frag);
            src_iterator.advance();
            dst_iterator.advance();
    MEGDNN_HOST MEGDNN_DEVICE MaskedTensorIteratorOverChannel()
            : pointer{nullptr},
              chan_stride_in_elements{0},
              channel{0} {}
    MEGDNN_HOST MEGDNN_DEVICE MaskedTensorIteratorOverChannel(
            Type* pointer_, int chan_stride_in_elements_, int channel_,
            int bound_, int div_)
            : pointer{pointer_},
              chan_stride_in_elements{chan_stride_in_elements_},
              channel{channel_},
              bound{bound_},
              div{div_} {
        cutlass::find_divisor(mul, shr, div);
    }
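    // initialize() precomputes one predicate bit and one element offset per
    // access: fast_divmod splits the width-padded spatial index by `div` (the
    // byte-aligned width) into (h, w); accesses whose channel index is past
    // `channel` or whose column lies in the padding (w >= bound, with `bound`
    // the true width) get a zero mask bit, and stride[] records the offset, in
    // Type units, into the unpadded packed layout whose row stride is `bound`.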
    MEGDNN_DEVICE __forceinline__ void initialize(int c_idx, int hw_idx) {
        pointer += (c_idx / pack_size) * chan_stride_in_elements;
#pragma unroll
        for (int i = 0; i < mask_size; ++i) {
            mask[i] = 0;
        }
#pragma unroll
        for (int i = 0; i < chan_blk; i += pack_size) {
#pragma unroll
            for (int j = 0; j < lane_size_in_type / pack_size_in_type; j++) {
                int offset = hw_idx + j;
                int h, w;
                cutlass::fast_divmod(h, w, offset, div, mul, shr);
                bool guard = (i < channel) && (w < bound);
                int index = (i / pack_size) *
                                    (lane_size_in_type / pack_size_in_type) +
                            j;
                int mask_index = (index >> 5);
                int mask_shift = (index & 0x1f);
                mask[mask_index] |= (guard << mask_shift);
                stride[j] = (h * bound + w) * pack_size * size_nbits /
                            (8 * sizeof(Type));
            }
        }
    }

    MEGDNN_DEVICE __forceinline__ void add_pointer_offset(size_t offset_in_type) {
        pointer += offset_in_type;
    }
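    // load()/store() issue one vectorized global access of pack_size_in_byte
    // bytes per precomputed stride[]; cutlass::arch::global_load/global_store
    // take the matching mask bit as a predicate, so guarded-off accesses are
    // skipped.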
    MEGDNN_DEVICE __forceinline__ void load(Fragment& frag) {
        AccessType* frag_ptr = reinterpret_cast<AccessType*>(&frag);
        Type* pointer_ = pointer;
#pragma unroll
        for (int i = 0; i < chan_blk; i += pack_size) {
#pragma unroll
            for (int j = 0; j < lane_size_in_type / pack_size_in_type; j++) {
                int frag_idx = i / pack_size *
                                       (lane_size_in_type / pack_size_in_type) +
                               j;
                int mask_index = (frag_idx >> 5);
                int mask_shift = (frag_idx & 0x1f);
                bool guard = (mask[mask_index] & (1 << mask_shift));
                cutlass::arch::global_load<AccessType, pack_size_in_byte>(
                        frag_ptr[frag_idx],
                        reinterpret_cast<void*>(pointer_ + stride[j]), guard);
            }
            pointer_ += chan_stride_in_elements;
        }
    }

    MEGDNN_DEVICE __forceinline__ void store(const Fragment& frag) {
        const AccessType* frag_ptr = reinterpret_cast<const AccessType*>(&frag);
        Type* pointer_ = pointer;
#pragma unroll
        for (int i = 0; i < chan_blk; i += pack_size) {
#pragma unroll
            for (int j = 0; j < lane_size_in_type / pack_size_in_type; j++) {
                int frag_idx = i / pack_size *
                                       (lane_size_in_type / pack_size_in_type) +
                               j;
                int mask_index = (frag_idx >> 5);
                int mask_shift = (frag_idx & 0x1f);
                bool guard = (mask[mask_index] & (1 << mask_shift));
                cutlass::arch::global_store<AccessType, pack_size_in_byte>(
                        frag_ptr[frag_idx],
                        reinterpret_cast<void*>(pointer_ + stride[j]), guard);
            }
            pointer_ += chan_stride_in_elements;
        }
    }

    MEGDNN_DEVICE __forceinline__ void advance() {
        pointer += (chan_blk / pack_size) * chan_stride_in_elements;
        channel -= chan_blk;
    }

private:
    Type* pointer;
    int chan_stride_in_elements;
    int channel;
    int bound;
    int div;
    uint32_t mul;
    uint32_t shr;
    uint32_t mask[mask_size];
    size_t stride[accesses];
};
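// TensorIteratorPolicy selects the iterator implementation at compile time:
// the masked variant when the 4-bit tensor needs width padding
// (padding_ == true), the plain TensorIteratorOverChannel otherwise.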
template <bool padding_, typename Type_, int pack_size_, int chan_blk_,
          int width_, int size_nbits_>
struct TensorIteratorPolicy;
template <typename Type_, int pack_size_, int chan_blk_, int width_,
          int size_nbits_>
struct TensorIteratorPolicy<true, Type_, pack_size_, chan_blk_, width_,
                            size_nbits_> {
    using TensorIterator =
            MaskedTensorIteratorOverChannel<Type_, pack_size_, chan_blk_,
                                            width_, size_nbits_>;
};
template <typename Type_, int pack_size_, int chan_blk_, int width_,
          int size_nbits_>
struct TensorIteratorPolicy<false, Type_, pack_size_, chan_blk_, width_,
                            size_nbits_> {
    using TensorIterator =
            TensorIteratorOverChannel<Type_, pack_size_, chan_blk_, width_,
                                      size_nbits_>;
};
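// RelayoutProblem bundles the compile-time pieces (both iterators, the tile
// transpose and the post-process) and checks that their tile shapes agree;
// Param carries the per-launch state: the two iterators, batch strides in
// elements, batch size, channel count and the spatial size hw.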
template <typename SrcIterator_, typename DstIterator_, typename Transpose_,
          typename CudaPostProcess_>
struct RelayoutProblem {
    using SrcIterator = SrcIterator_;
    using DstIterator = DstIterator_;
    using Transpose = Transpose_;
    using CudaPostProcess = CudaPostProcess_;
    MEGDNN_STATIC_ASSERT(SrcIterator::chan_blk == DstIterator::chan_blk,
                         "channel block mismatch");
    MEGDNN_STATIC_ASSERT(SrcIterator::width == DstIterator::width,
                         "width block mismatch");
    MEGDNN_STATIC_ASSERT(SrcIterator::size_nbits == DstIterator::size_nbits,
                         "size in bits of elements mismatch");
    static constexpr int pack_chan = SrcIterator::chan_blk;
    static constexpr int pack_width = SrcIterator::width;
    using DnnSrcType = typename CudaPostProcess::SrcType;
    using DnnDstType = typename CudaPostProcess::DstType;
    struct Param {
        SrcIterator src_iterator;
        DstIterator dst_iterator;
        CudaPostProcess post_process;
        int n_stride_src;
        int n_stride_dst;
        int batch_size;
        int channels;
        int hw;
        MEGDNN_HOST MEGDNN_DEVICE Param(SrcIterator src_iterator_,
                                        DstIterator dst_iterator_,
                                        CudaPostProcess post_process_,
                                        int n_stride_src_, int n_stride_dst_,
                                        int batch_size_, int channels_, int hw_)
                : src_iterator{src_iterator_},
                  dst_iterator{dst_iterator_},
                  post_process{post_process_},
                  n_stride_src{n_stride_src_},
                  n_stride_dst{n_stride_dst_},
                  batch_size{batch_size_},
                  channels{channels_},
                  hw{hw_} {}
    };
};
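// relayout_kern: every thread owns one pack_width-wide strip of one
// pack_chan-channel block. thread_offset is decomposed into a spatial index
// (hw_idx) and a combined batch/channel-block index, the iterators are
// positioned, a fragment is loaded, transposed through Translayout (which
// also applies the quantization post-process) and stored in the destination
// layout.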
template <typename RelayoutProblem_>
__global__ void relayout_kern(typename RelayoutProblem_::Param param) {
    using SrcIterator = typename RelayoutProblem_::SrcIterator;
    using DstIterator = typename RelayoutProblem_::DstIterator;
    static constexpr int pack_chan = RelayoutProblem_::pack_chan;
    static constexpr int pack_width = RelayoutProblem_::pack_width;
    const int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
    const int thread_offset = thread_idx * pack_width;
    const int hw_idx = (thread_offset % param.hw);
    const int nc_blks = thread_offset / param.hw;
    const int c_blks = (param.channels + pack_chan - 1) / pack_chan;
    const int n_idx = nc_blks / c_blks;
    const int c_blk_idx = nc_blks % c_blks;
    const int c_idx = c_blk_idx * pack_chan;
    if (n_idx < param.batch_size) {
        const int src_offset = n_idx * param.n_stride_src;
        const int dst_offset = n_idx * param.n_stride_dst;
        param.src_iterator.add_pointer_offset(src_offset);
        param.dst_iterator.add_pointer_offset(dst_offset);
        param.src_iterator.initialize(c_idx, hw_idx);
        param.dst_iterator.initialize(c_idx, hw_idx);
        typename SrcIterator::Fragment src_frag;
        typename DstIterator::Fragment dst_frag;
        param.src_iterator.load(src_frag);
        RelayoutProblem_::Transpose::trans(
                reinterpret_cast<typename SrcIterator::Fragment&>(dst_frag),
                src_frag, param.post_process);
        param.dst_iterator.store(dst_frag);
    }
}
}  // namespace

@@ -1175,21 +1381,23 @@ void relayout_format::relayout_format_cuda_nchw_nchwx(
                  "Unsupport pack size(pack_oc:%d, src:%s, dst:%s)", pack_oc,
                  stype.name(), dtype.name());
#undef DEF
    const int in_n = src.layout[0];
    const int out_n = dst.layout[0];
    const int ic = src.layout[1];
    const int h = src.layout[2];
    const int w = src.layout[3];
    const int oc = dst.layout[1] * pack_oc;
    const int hw = h * w;
    const int ocpg = oc / group;
    // stride in byte
    const int n_stride_src = src_layout.dtype.size(src_layout.stride[0]);
    const int ic_stride = src_layout.dtype.size(src_layout.stride[1]);
    const int n_stride_dst = dst_layout.dtype.size(dst_layout.stride[0]);
    const int oc_stride = dst_layout.dtype.size(dst_layout.stride[1]);
    // no padding
    if (src.layout.stride[2] == static_cast<ptrdiff_t>(src.layout[3])) {
        const int in_n = src.layout[0];
        const int out_n = dst.layout[0];
        const int ic = src.layout[1];
        const int h = src.layout[2];
        const int w = src.layout[3];
        const int oc = dst.layout[1] * pack_oc;
        const int hw = h * w;
        const int ocpg = oc / group;
        // stride in byte
        const int n_stride_src = src_layout.dtype.size(src_layout.stride[0]);
        const int ic_stride = src_layout.dtype.size(src_layout.stride[1]);
        const int n_stride_dst = dst_layout.dtype.size(dst_layout.stride[0]);
        const int oc_stride = dst_layout.dtype.size(dst_layout.stride[1]);

    bool same_scale = src_scale == dst_scale;
        bool same_scale = src_scale == dst_scale;
#define DISPATCH_RAW(_same_scale, _pack_w, _pack_oc, _src_type, _dst_type, \
                     _src_c_type, _dst_c_type, _size_nbits) \
    if (same_scale == _same_scale && hw % _pack_w == 0 && \
@@ -1225,19 +1433,95 @@ void relayout_format::relayout_format_cuda_nchw_nchwx(
    DISPATCH_RAW(false, 8, 64, _src_type, _dst_type, char, char, 4); \
    DISPATCH_RAW(true, 2, 64, _src_type, _dst_type, char, char, 4); \
    DISPATCH_RAW(false, 2, 64, _src_type, _dst_type, char, char, 4);
    DISPATCH_INT(QuantizedS32, QuantizedS32);
    DISPATCH_BYTE(Uint8, QuantizedS8);
    DISPATCH_BYTE(Quantized8Asymm, QuantizedS8);
    DISPATCH_BYTE(QuantizedS8, QuantizedS8);
    DISPATCH_4BITS(QuantizedS4, QuantizedS4);
    DISPATCH_4BITS(Quantized4Asymm, Quantized4Asymm);
        DISPATCH_INT(QuantizedS32, QuantizedS32);
        DISPATCH_BYTE(Uint8, QuantizedS8);
        DISPATCH_BYTE(Quantized8Asymm, QuantizedS8);
        DISPATCH_BYTE(QuantizedS8, QuantizedS8);
        DISPATCH_4BITS(QuantizedS4, QuantizedS4);
        DISPATCH_4BITS(Quantized4Asymm, Quantized4Asymm);
#undef DISPATCH_4BITS
#undef DISPATCH_BYTE
#undef DISPATCH_INT
#undef DISPATCH_RAW
    megdnn_assert(false,
                  "Unsupported data type(src:%s, dst:%s) or image size(%dx%d).",
                  stype.name(), dtype.name(), h, w);
        megdnn_assert(
                false,
                "Unsupported data type(src:%s, dst:%s) or image size(%dx%d).",
                stype.name(), dtype.name(), h, w);
    } else {
        megdnn_assert(src_layout.dtype.is_low_bit());
        int n = src.layout[0];
        int c = src.layout[1];
        int h = src.layout[2];
        // align to byte
        int w = src.layout[3];
        int w_pad = DIVUP(w, 2) * 2;
        int hw = h * w_pad;
        int n_stride_src = src_layout.stride[0];
        int ic_stride = src_layout.stride[1];
        int n_stride_dst = dst_layout.stride[0];
        int oc_stride = dst_layout.stride[1];
        int problem_size = n * (c / pack_oc) * hw;
        bool same_scale = src_scale == dst_scale;
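        // The dispatch below mirrors the contiguous path but builds a
        // RelayoutProblem: the NCHW side is walked with the plain iterator and
        // the NCHW64 side with the masked one, the strides are converted from
        // logical element counts to units of each iterator's storage type, and
        // relayout_kern is launched with roughly one thread per _pack_w
        // positions of problem_size.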
#define DISPATCH_RAW(_same_scale, _pack_w, _pack_oc, _src_type, _dst_type, \
                     _src_c_type, _dst_c_type, _size_nbits) \
    if (same_scale == _same_scale && hw % _pack_w == 0 && \
        stype.enumv().ev == DTypeEnum::Ev::_src_type && \
        dtype.enumv().ev == DTypeEnum::Ev::_dst_type) { \
        using InnerDtype_ = typename DTypeRWHelper< \
                typename DTypeTrait<dtype::_src_type>::ctype, \
                _pack_w>::InnerDtype; \
        using SrcIterator_ = \
                TensorIteratorOverChannel<InnerDtype_, 1, _pack_oc, _pack_w, \
                                          _size_nbits>; \
        using DstIterator_ = MaskedTensorIteratorOverChannel< \
                _dst_c_type, _pack_oc, _pack_oc, _pack_w, _size_nbits>; \
        using CudaPostProcess_ = \
                CudaPostProcess<dtype::_src_type, dtype::_dst_type, \
                                _same_scale>; \
        using Transpose_ = \
                Translayout<_pack_w, _pack_oc, _src_c_type, dtype::_src_type, \
                            dtype::_dst_type, _same_scale>; \
        using RelayoutProblem_ = \
                RelayoutProblem<SrcIterator_, DstIterator_, Transpose_, \
                                CudaPostProcess_>; \
        n_stride_src = n_stride_src * _size_nbits / (8 * sizeof(InnerDtype_)); \
        ic_stride = ic_stride * _size_nbits / (8 * sizeof(InnerDtype_)); \
        n_stride_dst = n_stride_dst * _size_nbits / (8 * sizeof(_dst_c_type)); \
        oc_stride = oc_stride * _size_nbits / (8 * sizeof(_dst_c_type)); \
        typename RelayoutProblem_::Param param{ \
                SrcIterator_{(InnerDtype_*)src.raw_ptr, ic_stride, c, w, \
                             w_pad}, \
                DstIterator_{(_dst_c_type*)dst.raw_ptr, oc_stride, c, w, \
                             w_pad}, \
                CudaPostProcess_{src_scale, src_zero_point, dst_scale, \
                                 dst_zero_point}, \
                n_stride_src, \
                n_stride_dst, \
                n, \
                c, \
                hw}; \
        auto kernel = relayout_kern<RelayoutProblem_>; \
        int nr_threads = query_blocksize_for_kernel(kernel); \
        nr_threads = std::min(nr_threads, DIVUP(problem_size, _pack_w)); \
        const dim3 block_dim(DIVUP(problem_size, nr_threads* _pack_w)); \
        const dim3 thread_dim(nr_threads); \
        return kernel<<<block_dim, thread_dim, 0, stream>>>(param); \
    }
#define DISPATCH_4BITS(_src_type, _dst_type) \
    DISPATCH_RAW(true, 8, 64, _src_type, _dst_type, char, char, 4); \
    DISPATCH_RAW(false, 8, 64, _src_type, _dst_type, char, char, 4); \
    DISPATCH_RAW(true, 2, 64, _src_type, _dst_type, char, char, 4); \
    DISPATCH_RAW(false, 2, 64, _src_type, _dst_type, char, char, 4);
        DISPATCH_4BITS(QuantizedS4, QuantizedS4);
        DISPATCH_4BITS(Quantized4Asymm, Quantized4Asymm);
#undef DISPATCH_4BITS
#undef DISPATCH_RAW
        megdnn_assert(
                false,
                "Unsupported data type(src:%s, dst:%s) or image size(%dx%d).",
                stype.name(), dtype.name(), h, w);
    }
    after_kernel_launch();
}

bool relayout_format::relayout_format_cuda_usable(
@@ -1283,43 +1567,77 @@ void relayout_format::relayout_format_cuda_nchwx_nchw(
    // clang-format on
    megdnn_assert(pack_ic == 64, "Unsupport pack size(pack_ic:%d)", pack_ic);
#undef DEF
    const int n = src.layout[0];
    const int c = src.layout[1] * pack_ic;
    const int h = src.layout[2];
    int n = src.layout[0];
    int c = src.layout[1] * pack_ic;
    int h = src.layout[2];
    // align to byte
    const int w = src.layout[3];
    const int hw = h * w;
    const int n_stride_src = src_layout.dtype.size(src_layout.stride[0]);
    const int ic_stride = src_layout.dtype.size(src_layout.stride[1]);
    const int n_stride_dst = dst_layout.dtype.size(dst_layout.stride[0]);
    const int oc_stride = dst_layout.dtype.size(dst_layout.stride[1]);
    int w = src.layout[3];
    int w_pad = DIVUP(w, 2) * 2;
    int hw = h * w_pad;
    int n_stride_src = src_layout.stride[0];
    int ic_stride = src_layout.stride[1];
    int n_stride_dst = dst_layout.stride[0];
    int oc_stride = dst_layout.stride[1];
    int problem_size = n * (c / pack_ic) * hw;

    bool same_scale = src_scale == dst_scale;
#define DISPATCH_RAW(_same_scale, _pack_w, _pack_oc, _src_type, _dst_type, \
                     _src_c_type, _dst_c_type, _size_nbits) \
    if (same_scale == _same_scale && hw % _pack_w == 0 && \
        stype.enumv().ev == DTypeEnum::Ev::_src_type && \
        dtype.enumv().ev == DTypeEnum::Ev::_dst_type) { \
        auto kernel = \
                kern_nchwx_nchw<_pack_w, _pack_oc, _same_scale, _src_c_type, \
                                _dst_c_type, dtype::_src_type, \
                                dtype::_dst_type, _size_nbits>; \
        int nr_threads = query_blocksize_for_kernel(kernel); \
        const dim3 block_dim(DIVUP(hw, nr_threads* _pack_w), n); \
        const dim3 thread_dim(nr_threads); \
        return kernel<<<block_dim, thread_dim, 0, stream>>>( \
                (_src_c_type*)src.raw_ptr, (_dst_c_type*)dst.raw_ptr, c, hw, \
                n_stride_src, ic_stride, n_stride_dst, oc_stride, \
                CudaPostProcess<dtype::_src_type, dtype::_dst_type, \
                                _same_scale>(src_scale, src_zero_point, \
                                             dst_scale, dst_zero_point), \
                src_zero_point); \
    bool padding = w % 2 != 0;
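    // An odd width means the byte-aligned NCHW view carries one padded column,
    // so TensorIteratorPolicy picks the masked iterator for the NCHW64 source;
    // the DISPATCH_RAW below therefore expands over padding as well as
    // same_scale and the two supported pack widths.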
#define DISPATCH_RAW(_padding, _same_scale, _pack_w, _pack_oc, _src_type, \
                     _dst_type, _src_c_type, _dst_c_type, _size_nbits) \
    if (padding == _padding && same_scale == _same_scale && \
        hw % _pack_w == 0 && stype.enumv().ev == DTypeEnum::Ev::_src_type && \
        dtype.enumv().ev == DTypeEnum::Ev::_dst_type) { \
        using SrcIterator_ = \
                typename TensorIteratorPolicy<_padding, _src_c_type, _pack_oc, \
                                              _pack_oc, _pack_w, \
                                              _size_nbits>::TensorIterator; \
        using InnerDtype_ = typename DTypeRWHelper< \
                typename DTypeTrait<dtype::_src_type>::ctype, \
                _pack_w>::InnerDtype; \
        using DstIterator_ = \
                TensorIteratorOverChannel<InnerDtype_, 1, _pack_oc, _pack_w, \
                                          _size_nbits>; \
        using CudaPostProcess_ = \
                CudaPostProcess<dtype::_src_type, dtype::_dst_type, \
                                _same_scale>; \
        using Transpose_ = \
                Translayout<_pack_oc, _pack_w, _src_c_type, dtype::_src_type, \
                            dtype::_dst_type, _same_scale>; \
        using RelayoutProblem_ = \
                RelayoutProblem<SrcIterator_, DstIterator_, Transpose_, \
                                CudaPostProcess_>; \
        n_stride_src = n_stride_src * _size_nbits / (8 * sizeof(_src_c_type)); \
        ic_stride = ic_stride * _size_nbits / (8 * sizeof(_src_c_type)); \
        n_stride_dst = n_stride_dst * _size_nbits / (8 * sizeof(InnerDtype_)); \
        oc_stride = oc_stride * _size_nbits / (8 * sizeof(InnerDtype_)); \
        typename RelayoutProblem_::Param param{ \
                SrcIterator_{(_src_c_type*)src.raw_ptr, ic_stride, c, w, \
                             w_pad}, \
                DstIterator_{(InnerDtype_*)dst.raw_ptr, oc_stride, c, w, \
                             w_pad}, \
                CudaPostProcess_{src_scale, src_zero_point, dst_scale, \
                                 dst_zero_point}, \
                n_stride_src, \
                n_stride_dst, \
                n, \
                c, \
                hw}; \
        auto kernel = relayout_kern<RelayoutProblem_>; \
        int nr_threads = query_blocksize_for_kernel(kernel); \
        nr_threads = std::min(nr_threads, DIVUP(problem_size, _pack_w)); \
        const dim3 block_dim(DIVUP(problem_size, nr_threads* _pack_w)); \
        const dim3 thread_dim(nr_threads); \
        return kernel<<<block_dim, thread_dim, 0, stream>>>(param); \
    }
#define DISPATCH_4BITS(_src_type, _dst_type) \
    DISPATCH_RAW(true, 8, 64, _src_type, _dst_type, char, char, 4); \
    DISPATCH_RAW(false, 8, 64, _src_type, _dst_type, char, char, 4); \
    DISPATCH_RAW(true, 2, 64, _src_type, _dst_type, char, char, 4); \
    DISPATCH_RAW(false, 2, 64, _src_type, _dst_type, char, char, 4);
#define DISPATCH_4BITS(_src_type, _dst_type) \
    DISPATCH_RAW(true, true, 8, 64, _src_type, _dst_type, char, char, 4); \
    DISPATCH_RAW(true, false, 8, 64, _src_type, _dst_type, char, char, 4); \
    DISPATCH_RAW(true, true, 2, 64, _src_type, _dst_type, char, char, 4); \
    DISPATCH_RAW(true, false, 2, 64, _src_type, _dst_type, char, char, 4); \
    DISPATCH_RAW(false, true, 8, 64, _src_type, _dst_type, char, char, 4); \
    DISPATCH_RAW(false, false, 8, 64, _src_type, _dst_type, char, char, 4); \
    DISPATCH_RAW(false, true, 2, 64, _src_type, _dst_type, char, char, 4); \
    DISPATCH_RAW(false, false, 2, 64, _src_type, _dst_type, char, char, 4);
    DISPATCH_4BITS(QuantizedS4, QuantizedS4);
    DISPATCH_4BITS(Quantized4Asymm, Quantized4Asymm);
#undef DISPATCH_4BITS
@@ -1327,6 +1645,7 @@ void relayout_format::relayout_format_cuda_nchwx_nchw(
    megdnn_assert(false,
                  "Unsupported data type(src:%s, dst:%s) or image size(%dx%d).",
                  stype.name(), dtype.name(), h, w);
    after_kernel_launch();
}

void relayout_format::relayout_format_cuda_nchw4_nchw(
@@ -1344,6 +1663,7 @@ void relayout_format::relayout_format_cuda_nchw4_nchw(
    const dim3 thread_dim(nr_threads);
    kern_nchw4_nchw<<<block_dim, thread_dim, 0, stream>>>(
            (int8_t*)src.raw_ptr, (int8_t*)dst.raw_ptr, n, ic, oc, h, w, group);
    after_kernel_launch();
}

void relayout_format::relayout_format_cuda_nchw_nchw4_weight(
@@ -1372,4 +1692,5 @@ void relayout_format::relayout_format_cuda_nchw_nchw4_weight(
            (char*)src.raw_ptr, (char*)dst.raw_ptr, oc, ic, hw, oc_stride_src,
            ic_stride, oc_stride_dst, group_stride_src, group_stride_dst, 0,
            {});
    after_kernel_launch();
}