GitOrigin-RevId: e9ecbcf2e2
release-1.5
@@ -252,10 +252,10 @@ void RelayoutFormat::deduce_layout_fwd(const TensorLayout& src,
             megdnn_assert(dst[1] % param().group == 0);
             break;
         case Param::Mode::NCHW_NCHW64:
-            megdnn_assert(src.ndim == 4 && (src[1] % 64) == 0);
+            megdnn_assert(src.ndim == 4);
             dst.ndim = 5;
             dst[0] = src[0];
-            dst[1] = src[1] / 64;
+            dst[1] = div_ceil(src[1], 64_z);
             dst[2] = src[2];
             dst[3] = src[3];
             dst[4] = 64;
@@ -264,7 +264,7 @@ void RelayoutFormat::deduce_layout_fwd(const TensorLayout& src,
             megdnn_assert(src.ndim == 5);
             dst.ndim = 4;
             dst[0] = src[0];
-            dst[1] = param().oc == 0 ? src[1] * 64 : param().oc;
+            dst[1] = param().oc == 0 ? src[1] * 64 : param().oc;
             dst[2] = src[2];
             dst[3] = src[3];
             break;
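With the two hunks above, NCHW_NCHW64 now accepts channel counts that are not multiples of 64 (the block count rounds up), and NCHW64_NCHW can restore the original channel count via param().oc. A minimal host-side sketch of the deduced-shape arithmetic; div_ceil is written out here for illustration and stands in for megdnn's helper:

    #include <cassert>
    #include <cstddef>

    // Illustrative stand-in for megdnn's div_ceil.
    static size_t div_ceil(size_t a, size_t b) { return (a + b - 1) / b; }

    int main() {
        // NCHW_NCHW64: c == 15 is no longer rejected; it deduces one block of 64.
        assert(div_ceil(15, 64) == 1);   // dst[1] for src[1] == 15
        assert(div_ceil(128, 64) == 2);  // exact multiples behave as before
        // NCHW64_NCHW: param().oc == 15 restores dst[1] == 15 instead of 64.
        size_t src1 = 1, oc = 15;
        size_t dst1 = (oc == 0) ? src1 * 64 : oc;
        assert(dst1 == 15);
        return 0;
    }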
@@ -483,12 +483,11 @@ void RelayoutFormat::deduce_exec_layout(const TensorLayout& src,
         case Param::Mode::NCHW4_NCHW:
             // nchw to nchw4
             {
-                megdnn_assert(src.format == dst.format);
                 exec_workspace =
                         TensorLayout({src[0], src[1] * 4, src[2], src[3]},
-                                     src.dtype, src.format)
-                                .reshape({src[0], src[1], 4, src[2], src[3]})
-                                .dimshuffle({0, 1, 3, 4, 2});
-                exec_src = src;
+                                     dst.dtype, dst.format);
+                exec_src = src.dimshuffle({0, 1, 4, 2, 3});
                 exec_dst = dst;
             }
             break;
@@ -658,13 +657,20 @@ void RelayoutFormat::deduce_exec_layout(const TensorLayout& src,
         case Param::Mode::NCHW_NCHW64:
             // src is {N, C, H, W}
             // dst is {N, C/64, H, W, 64}
-            exec_src = src.reshape({src[0], src[1] / 64, 64, src[2], src[3]})
+            exec_workspace = TensorLayout(
+                    {src[0], round_up(src[1], 64_z), src[2], src[3]},
+                    src.dtype);
+            exec_src = exec_workspace
+                               .reshape({src[0], div_ceil(src[1], 64_z), 64,
+                                         src[2], src[3]})
                                .dimshuffle({0, 1, 3, 4, 2});
             exec_dst = dst;
             break;
         case Param::Mode::NCHW64_NCHW:
             // src is {N, C/64, H, W, 64}
             // dst is {N, C, H, W}
+            exec_workspace = TensorLayout({src[0], src[1] * 64, src[2], src[3]},
+                                          dst.dtype);
             exec_src = src.dimshuffle({0, 1, 4, 2, 3});
             exec_dst = dst;
             break;
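The reshape/dimshuffle pair above only permutes strides on the padded workspace; data moves when the relayout actually executes. The index mapping it encodes, as a hypothetical host-side check:

    #include <array>
    #include <cassert>
    #include <cstddef>

    // Element (n, c, h, w) of the padded NCHW tensor lands at
    // (n, c / 64, h, w, c % 64) in NCHW64.
    static std::array<size_t, 5> nchw_to_nchw64(size_t n, size_t c, size_t h,
                                                size_t w) {
        return {n, c / 64, h, w, c % 64};
    }

    int main() {
        auto idx = nchw_to_nchw64(0, 70, 3, 5);
        assert(idx[1] == 1 && idx[4] == 6);  // channel 70 -> block 1, lane 6
        return 0;
    }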
@@ -0,0 +1,149 @@
+/**
+ * \file dnn/src/cuda/relayout_format/helper.cuh
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
+ * implied.
+ */
+
+namespace megdnn {
+namespace cuda {
+namespace relayout_format {
+
+#define devfunc __forceinline__ __device__
+
+template <int size_nbits>
+devfunc int make_zero(int zero_point);
+
+template <>
+devfunc int make_zero<4>(int zero_point) {
+    return transform_int8_to_uint4x8(zero_point, zero_point, zero_point,
+                                     zero_point, zero_point, zero_point,
+                                     zero_point, zero_point);
+}
+
+template <typename AccessType, int LoadBytes>
+struct global_load_with_zero_point;
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+//
+// Specializations
+//
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+// The redundant mov PTX instruction is used to enforce the compiler to
+// initialize data to zero before ld.global
+template <typename AccessType>
+struct global_load_with_zero_point<AccessType, 32> {
+    devfunc global_load_with_zero_point(AccessType& D, void const* ptr,
+                                        bool pred_guard, int zero_point) {
+        uint4* data = reinterpret_cast<uint4*>(&D);
+        asm volatile(
+                "{\n"
+                "  .reg .pred p;\n"
+                "  setp.ne.b32 p, %9, 0;\n"
+                "  mov.b32 %0, %10;\n"
+                "  mov.b32 %1, %10;\n"
+                "  mov.b32 %2, %10;\n"
+                "  mov.b32 %3, %10;\n"
+                "  mov.b32 %4, %10;\n"
+                "  mov.b32 %5, %10;\n"
+                "  mov.b32 %6, %10;\n"
+                "  mov.b32 %7, %10;\n"
+                "  @p ld.global.v4.u32 {%0, %1, %2, %3}, [%8];\n"
+                "  @p ld.global.v4.u32 {%4, %5, %6, %7}, [%11];\n"
+                "}\n"
+                : "=r"(data[0].x), "=r"(data[0].y), "=r"(data[0].z),
+                  "=r"(data[0].w), "=r"(data[1].x), "=r"(data[1].y),
+                  "=r"(data[1].z), "=r"(data[1].w)
+                : "l"(ptr), "r"((int)pred_guard),
+                  "r"(reinterpret_cast<unsigned&>(zero_point)),
+                  "l"(((uint8_t*)ptr) + 16));
+    }
+};
+
+template <typename AccessType>
+struct global_load_with_zero_point<AccessType, 16> {
+    devfunc global_load_with_zero_point(AccessType& D, void const* ptr,
+                                        bool pred_guard, int zero_point) {
+        uint4& data = reinterpret_cast<uint4&>(D);
+        asm volatile(
+                "{\n"
+                "  .reg .pred p;\n"
+                "  setp.ne.b32 p, %5, 0;\n"
+                "  mov.b32 %0, %6;\n"
+                "  mov.b32 %1, %6;\n"
+                "  mov.b32 %2, %6;\n"
+                "  mov.b32 %3, %6;\n"
+                "  @p ld.global.v4.u32 {%0, %1, %2, %3}, [%4];\n"
+                "}\n"
+                : "=r"(data.x), "=r"(data.y), "=r"(data.z), "=r"(data.w)
+                : "l"(ptr), "r"((int)pred_guard),
+                  "r"(reinterpret_cast<unsigned&>(zero_point)));
+    }
+};
+
+template <typename AccessType>
+struct global_load_with_zero_point<AccessType, 8> {
+    devfunc global_load_with_zero_point(AccessType& D, void const* ptr,
+                                        bool pred_guard, int zero_point) {
+        uint2& data = reinterpret_cast<uint2&>(D);
+        asm volatile(
+                "{\n"
+                "  .reg .pred p;\n"
+                "  setp.ne.b32 p, %3, 0;\n"
+                "  mov.b32 %0, %4;\n"
+                "  mov.b32 %1, %4;\n"
+                "  @p ld.global.v2.u32 {%0, %1}, [%2];\n"
+                "}\n"
+                : "=r"(data.x), "=r"(data.y)
+                : "l"(ptr), "r"((int)pred_guard),
+                  "r"(reinterpret_cast<unsigned&>(zero_point)));
+    }
+};
+
+template <typename AccessType>
+struct global_load_with_zero_point<AccessType, 4> {
+    devfunc global_load_with_zero_point(AccessType& D, void const* ptr,
+                                        bool pred_guard, int zero_point) {
+        unsigned& data = reinterpret_cast<unsigned&>(D);
+        asm volatile(
+                "{\n"
+                "  .reg .pred p;\n"
+                "  setp.ne.b32 p, %2, 0;\n"
+                "  mov.b32 %0, %3;\n"
+                "  @p ld.global.u32 %0, [%1];\n"
+                "}\n"
+                : "=r"(data)
+                : "l"(ptr), "r"((int)pred_guard),
+                  "r"(reinterpret_cast<unsigned&>(zero_point)));
+    }
+};
+
+template <typename AccessType>
+struct global_load_with_zero_point<AccessType, 1> {
+    devfunc global_load_with_zero_point(AccessType& D, void const* ptr,
+                                        bool pred_guard, int zero_point) {
+        if (pred_guard)
+            D = *(reinterpret_cast<AccessType const*>(ptr));
+        else {
+            unsigned uv = reinterpret_cast<unsigned&>(zero_point);
+            uint8_t& data = reinterpret_cast<uint8_t&>(D);
+            data = uv & 0xff;
+        }
+    }
+};
+
+#undef devfunc
+
+} // namespace relayout_format
+} // namespace cuda
+} // namespace megdnn
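The inline PTX in this header is, functionally, a predicated vector load whose fallback value is the packed zero point: the seemingly redundant mov instructions pre-fill the destination registers so masked-off lanes hold padding rather than stale data. A scalar C++ rendering of the 4-byte contract (a sketch only; it ignores vectorization and the register-initialization purpose of the asm form):

    #include <cstdint>
    #include <cstring>

    // Scalar model of global_load_with_zero_point<AccessType, 4>.
    inline void load_or_fill(uint32_t& dst, const void* ptr, bool pred_guard,
                             uint32_t packed_zero_point) {
        dst = packed_zero_point;                    // the "redundant" mov
        if (pred_guard) std::memcpy(&dst, ptr, 4);  // the predicated ld.global
    }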
@@ -18,6 +18,7 @@
 #pragma GCC diagnostic pop
 #include "src/cuda/query_blocksize.cuh"
 #include "src/cuda/relayout_format/relayout_format.cuh"
+#include "src/cuda/relayout_format/helper.cuh"
 using namespace megdnn;
 using namespace cuda;
@@ -728,17 +729,18 @@ struct Translayout<64, 2, SrcType, dtype::Quantized4Asymm,
 #undef pack
 template <typename DstType>
-inline __device__ DstType make_zero_pad(const char zero_point) {
+inline __device__ DstType make_zero_pad(const uint8_t zero_point) {
     return zero_point;
 }
 template <>
-inline __device__ char4 make_zero_pad<char4>(const char zero_point) {
-    return {zero_point, zero_point, zero_point, zero_point};
+inline __device__ char4 make_zero_pad<char4>(const uint8_t zero_point) {
+    char izp = reinterpret_cast<const char&>(zero_point);
+    return {izp, izp, izp, izp};
 }
 template <>
-inline __device__ int4 make_zero_pad<int4>(const char zero_point) {
+inline __device__ int4 make_zero_pad<int4>(const uint8_t zero_point) {
     return {zero_point, zero_point, zero_point, zero_point};
 }
@@ -767,7 +769,7 @@ inline __device__ void write_helper<array_wrapper<uint8_t, 32>>(
             : "l"(ptr_), "r"(data[0].x), "r"(data[0].y), "r"(data[0].z),
               "r"(data[0].w), "l"(((uint8_t*)ptr_) + 16), "r"(data[1].x),
               "r"(data[1].y), "r"(data[1].z), "r"(data[1].w));
-};
+}
 template <bool with_pad, int pack_w, int pack_c, bool same_scale, bool all_pad,
           typename SrcType, typename DstType, typename DnnSrcType,
@@ -825,7 +827,7 @@ struct RelayoutKern {
             const SrcType* src, DstType* dst, const int ic_stride,
             const int remain_ic,
             CudaPostProcess<DnnSrcType, DnnDstType, same_scale>& post_process,
-            const char zero_point) {
+            const uint8_t zero_point) {
         InnerDtype read_channel[pack_c];
         if (all_pad) {
             const InnerDtype zero_pad = make_zero_pad<InnerDtype>(zero_point);
@@ -855,7 +857,7 @@ __global__ void kern_nchw_nchwx(
         const SrcType* src, DstType* dst, int in_n, int ic, int ihw,
         int n_stride_src, int ic_stride, int n_stride_dst, int oc_stride,
         CudaPostProcess<DnnSrcType, DnnDstType, same_scale> post_process,
-        const char zero_point, const int group, const int ocpg) {
+        const uint8_t zero_point, const int group, const int ocpg) {
     static constexpr int size_src_type = sizeof(SrcType);
     static constexpr int size_dst_type = sizeof(DstType);
 #ifndef MEGDNN_COMMA
@@ -1072,6 +1074,7 @@ public:
     MEGDNN_DEVICE __forceinline__ void initialize(int c_idx, int hw_idx) {
         pointer += (c_idx / pack_size) * chan_stride_in_elements +
                    hw_idx * pack_size * size_nbits / (8 * sizeof(Type));
+        channel -= c_idx;
     }
     MEGDNN_DEVICE __forceinline__ void add_pointer_offset(
@@ -1079,7 +1082,7 @@ public:
         pointer += offset_in_type;
     }
-    MEGDNN_DEVICE __forceinline__ void load(Fragment& frag) {
+    MEGDNN_DEVICE __forceinline__ void load(Fragment& frag, int zero_point) {
         AccessType* frag_ptr = reinterpret_cast<AccessType*>(&frag);
         Type* pointer_ = pointer;
 #pragma unroll
@@ -1090,11 +1093,12 @@ public:
                                 (lane_size_in_type / pack_size_in_type) +
                         j;
                 bool guard = i < channel;
-                cutlass::arch::global_load<AccessType, pack_size_in_byte>(
+                relayout_format::global_load_with_zero_point<AccessType,
+                                                             pack_size_in_byte>(
                         frag_ptr[frag_idx],
                         reinterpret_cast<void*>(pointer_ +
                                                 j * pack_size_in_type),
-                        guard);
+                        guard, zero_point);
             }
             pointer_ += chan_stride_in_elements;
         }
@@ -1173,6 +1177,7 @@ public:
     MEGDNN_DEVICE __forceinline__ void initialize(int c_idx, int hw_idx) {
         pointer += (c_idx / pack_size) * chan_stride_in_elements;
+        channel -= c_idx;
 #pragma unroll
         for (int i = 0; i < mask_size; ++i) {
             mask[i] = 0;
@@ -1201,7 +1206,7 @@ public:
         pointer += offset_in_type;
     }
-    MEGDNN_DEVICE __forceinline__ void load(Fragment& frag) {
+    MEGDNN_DEVICE __forceinline__ void load(Fragment& frag, int zero_point) {
         AccessType* frag_ptr = reinterpret_cast<AccessType*>(&frag);
         Type* pointer_ = pointer;
 #pragma unroll
@@ -1214,9 +1219,11 @@ public:
                 int mask_index = (frag_idx >> 5);
                 int mask_shift = (frag_idx & 0x1f);
                 bool guard = (mask[mask_index] & (1 << mask_shift));
-                cutlass::arch::global_load<AccessType, pack_size_in_byte>(
+                relayout_format::global_load_with_zero_point<AccessType,
+                                                             pack_size_in_byte>(
                         frag_ptr[frag_idx],
-                        reinterpret_cast<void*>(pointer_ + stride[j]), guard);
+                        reinterpret_cast<void*>(pointer_ + stride[j]), guard,
+                        zero_point);
             }
             pointer_ += chan_stride_in_elements;
         }
@@ -1306,11 +1313,13 @@ struct RelayoutProblem {
         int batch_size;
         int channels;
         int hw;
+        int zero_point;
         MEGDNN_HOST MEGDNN_DEVICE Param(SrcIterator src_iterator_,
                                         DstIterator dst_iterator_,
                                         CudaPostProcess post_process_,
                                         int n_stride_src_, int n_stride_dst_,
-                                        int batch_size_, int channels_, int hw_)
+                                        int batch_size_, int channels_, int hw_,
+                                        int zero_point_)
                 : src_iterator{src_iterator_},
                   dst_iterator{dst_iterator_},
                   post_process{post_process_},
@@ -1318,7 +1327,8 @@ struct RelayoutProblem {
                   n_stride_dst{n_stride_dst_},
                   batch_size{batch_size_},
                   channels{channels_},
-                  hw{hw_} {}
+                  hw{hw_},
+                  zero_point{zero_point_} {}
     };
 };
@@ -1345,7 +1355,9 @@ __global__ void relayout_kern(typename RelayoutProblem_::Param param) {
         param.dst_iterator.initialize(c_idx, hw_idx);
         typename SrcIterator::Fragment src_frag;
         typename DstIterator::Fragment dst_frag;
-        param.src_iterator.load(src_frag);
+        int zp = relayout_format::make_zero<SrcIterator::size_nbits>(
+                param.zero_point);
+        param.src_iterator.load(src_frag, zp);
         RelayoutProblem_::Transpose::trans(
                 reinterpret_cast<typename SrcIterator::Fragment&>(dst_frag),
                 src_frag, param.post_process);
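make_zero<4> packs the zero point into every nibble of a 32-bit word (via transform_int8_to_uint4x8), so one register supplies padding for eight 4-bit elements at a time. For a zero point already in [0, 15] the packing should reduce to a multiply by 0x11111111, which this host-side sketch assumes:

    #include <cassert>
    #include <cstdint>

    // Assumed equivalent of make_zero<4> for an in-range unsigned zero point.
    inline uint32_t broadcast_u4x8(uint32_t zp) {
        assert(zp <= 0xfu);
        return zp * 0x11111111u;  // copy the nibble into all eight positions
    }

    int main() {
        assert(broadcast_u4x8(0x8) == 0x88888888u);
        return 0;
    }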
@@ -1382,7 +1394,8 @@ void relayout_format::relayout_format_cuda_nchw_nchwx(
                   stype.name(), dtype.name());
 #undef DEF
     // no padding
-    if (src.layout.stride[2] == static_cast<ptrdiff_t>(src.layout[3])) {
+    if (stype.enumv().ev != DTypeEnum::Ev::QuantizedS4 &&
+        stype.enumv().ev != DTypeEnum::Ev::Quantized4Asymm) {
         const int in_n = src.layout[0];
         const int out_n = dst.layout[0];
         const int ic = src.layout[1];
@@ -1428,18 +1441,10 @@ void relayout_format::relayout_format_cuda_nchw_nchwx(
     DISPATCH_RAW(false, 4, 4, _src_type, _dst_type, char, char, 8);          \
     DISPATCH_RAW(true, 1, 4, _src_type, _dst_type, char, char, 8);           \
     DISPATCH_RAW(false, 1, 4, _src_type, _dst_type, char, char, 8);
-#define DISPATCH_4BITS(_src_type, _dst_type)                                 \
-    DISPATCH_RAW(true, 8, 64, _src_type, _dst_type, char, char, 4);          \
-    DISPATCH_RAW(false, 8, 64, _src_type, _dst_type, char, char, 4);         \
-    DISPATCH_RAW(true, 2, 64, _src_type, _dst_type, char, char, 4);          \
-    DISPATCH_RAW(false, 2, 64, _src_type, _dst_type, char, char, 4);
         DISPATCH_INT(QuantizedS32, QuantizedS32);
         DISPATCH_BYTE(Uint8, QuantizedS8);
         DISPATCH_BYTE(Quantized8Asymm, QuantizedS8);
         DISPATCH_BYTE(QuantizedS8, QuantizedS8);
-        DISPATCH_4BITS(QuantizedS4, QuantizedS4);
-        DISPATCH_4BITS(Quantized4Asymm, Quantized4Asymm);
-#undef DISPATCH_4BITS
 #undef DISPATCH_BYTE
 #undef DISPATCH_INT
 #undef DISPATCH_RAW
@@ -1450,7 +1455,8 @@ void relayout_format::relayout_format_cuda_nchw_nchwx(
     } else {
         megdnn_assert(src_layout.dtype.is_low_bit());
         int n = src.layout[0];
-        int c = src.layout[1];
+        int ic = src.layout[1];
+        int oc = dst.layout[1] * 64;
         int h = src.layout[2];
         // align to byte
         int w = src.layout[3];
@@ -1460,12 +1466,13 @@ void relayout_format::relayout_format_cuda_nchw_nchwx(
         int ic_stride = src_layout.stride[1];
         int n_stride_dst = dst_layout.stride[0];
         int oc_stride = dst_layout.stride[1];
-        int problem_size = n * (c / pack_oc) * hw;
+        int problem_size = n * (oc / pack_oc) * hw;
         bool same_scale = src_scale == dst_scale;
-#define DISPATCH_RAW(_same_scale, _pack_w, _pack_oc, _src_type, _dst_type,   \
-                     _src_c_type, _dst_c_type, _size_nbits)                  \
-    if (same_scale == _same_scale && hw % _pack_w == 0 &&                    \
-        stype.enumv().ev == DTypeEnum::Ev::_src_type &&                      \
+        bool padding = w % 2 != 0;
+#define DISPATCH_RAW(_padding, _same_scale, _pack_w, _pack_oc, _src_type,    \
+                     _dst_type, _src_c_type, _dst_c_type, _size_nbits)       \
+    if (padding == _padding && same_scale == _same_scale &&                  \
+        hw % _pack_w == 0 && stype.enumv().ev == DTypeEnum::Ev::_src_type && \
         dtype.enumv().ev == DTypeEnum::Ev::_dst_type) {                      \
         using InnerDtype_ = typename DTypeRWHelper<                          \
                 typename DTypeTrait<dtype::_src_type>::ctype,                \
@@ -1473,8 +1480,10 @@ void relayout_format::relayout_format_cuda_nchw_nchwx(
         using SrcIterator_ =                                                 \
                 TensorIteratorOverChannel<InnerDtype_, 1, _pack_oc, _pack_w, \
                                           _size_nbits>;                      \
-        using DstIterator_ = MaskedTensorIteratorOverChannel<                \
-                _dst_c_type, _pack_oc, _pack_oc, _pack_w, _size_nbits>;      \
+        using DstIterator_ =                                                 \
+                typename TensorIteratorPolicy<_padding, _dst_c_type, _pack_oc,\
+                                              _pack_oc, _pack_w,             \
+                                              _size_nbits>::TensorIterator;  \
         using CudaPostProcess_ =                                             \
                 CudaPostProcess<dtype::_src_type, dtype::_dst_type,          \
                                 _same_scale>;                                \
@@ -1489,17 +1498,18 @@ void relayout_format::relayout_format_cuda_nchw_nchwx(
     n_stride_dst = n_stride_dst * _size_nbits / (8 * sizeof(_dst_c_type));   \
     oc_stride = oc_stride * _size_nbits / (8 * sizeof(_dst_c_type));         \
     typename RelayoutProblem_::Param param{                                  \
-            SrcIterator_{(InnerDtype_*)src.raw_ptr, ic_stride, c, w,         \
+            SrcIterator_{(InnerDtype_*)src.raw_ptr, ic_stride, ic, w,        \
                          w_pad},                                             \
-            DstIterator_{(_dst_c_type*)dst.raw_ptr, oc_stride, c, w,         \
+            DstIterator_{(_dst_c_type*)dst.raw_ptr, oc_stride, oc, w,        \
                          w_pad},                                             \
             CudaPostProcess_{src_scale, src_zero_point, dst_scale,           \
                              dst_zero_point},                                \
             n_stride_src,                                                    \
             n_stride_dst,                                                    \
             n,                                                               \
-            c,                                                               \
-            hw};                                                             \
+            oc,                                                              \
+            hw,                                                              \
+            src_zero_point};                                                 \
     auto kernel = relayout_kern<RelayoutProblem_>;                           \
     int nr_threads = query_blocksize_for_kernel(kernel);                     \
     nr_threads = std::min(nr_threads, DIVUP(problem_size, _pack_w));         \
@@ -1507,11 +1517,15 @@ void relayout_format::relayout_format_cuda_nchw_nchwx(
     const dim3 thread_dim(nr_threads);                                       \
     return kernel<<<block_dim, thread_dim, 0, stream>>>(param);              \
     }
-#define DISPATCH_4BITS(_src_type, _dst_type)                                 \
-    DISPATCH_RAW(true, 8, 64, _src_type, _dst_type, char, char, 4);          \
-    DISPATCH_RAW(false, 8, 64, _src_type, _dst_type, char, char, 4);         \
-    DISPATCH_RAW(true, 2, 64, _src_type, _dst_type, char, char, 4);          \
-    DISPATCH_RAW(false, 2, 64, _src_type, _dst_type, char, char, 4);
+#define DISPATCH_4BITS(_src_type, _dst_type)                                 \
+    DISPATCH_RAW(true, true, 8, 64, _src_type, _dst_type, char, char, 4);    \
+    DISPATCH_RAW(true, false, 8, 64, _src_type, _dst_type, char, char, 4);   \
+    DISPATCH_RAW(true, true, 2, 64, _src_type, _dst_type, char, char, 4);    \
+    DISPATCH_RAW(true, false, 2, 64, _src_type, _dst_type, char, char, 4);   \
+    DISPATCH_RAW(false, true, 8, 64, _src_type, _dst_type, char, char, 4);   \
+    DISPATCH_RAW(false, false, 8, 64, _src_type, _dst_type, char, char, 4);  \
+    DISPATCH_RAW(false, true, 2, 64, _src_type, _dst_type, char, char, 4);   \
+    DISPATCH_RAW(false, false, 2, 64, _src_type, _dst_type, char, char, 4);
         DISPATCH_4BITS(QuantizedS4, QuantizedS4);
         DISPATCH_4BITS(Quantized4Asymm, Quantized4Asymm);
 #undef DISPATCH_4BITS
@@ -1521,7 +1535,6 @@ void relayout_format::relayout_format_cuda_nchw_nchwx(
                 "Unsupported data type(src:%s, dst:%s) or image size(%dx%d).",
                 stype.name(), dtype.name(), h, w);
     }
-    after_kernel_launch();
 }
 bool relayout_format::relayout_format_cuda_usable(
@@ -1568,7 +1581,7 @@ void relayout_format::relayout_format_cuda_nchwx_nchw(
     megdnn_assert(pack_ic == 64, "Unsupport pack size(pack_ic:%d)", pack_ic);
 #undef DEF
     int n = src.layout[0];
-    int c = src.layout[1] * pack_ic;
+    int ic = src.layout[1] * pack_ic;
     int h = src.layout[2];
     // align to byte
     int w = src.layout[3];
@@ -1578,7 +1591,8 @@ void relayout_format::relayout_format_cuda_nchwx_nchw(
     int ic_stride = src_layout.stride[1];
     int n_stride_dst = dst_layout.stride[0];
     int oc_stride = dst_layout.stride[1];
-    int problem_size = n * (c / pack_ic) * hw;
+    int problem_size = n * (ic / pack_ic) * hw;
+    int oc = dst.layout[1];
     bool same_scale = src_scale == dst_scale;
     bool padding = w % 2 != 0;
@@ -1611,17 +1625,18 @@ void relayout_format::relayout_format_cuda_nchwx_nchw(
     n_stride_dst = n_stride_dst * _size_nbits / (8 * sizeof(InnerDtype_));   \
     oc_stride = oc_stride * _size_nbits / (8 * sizeof(InnerDtype_));         \
     typename RelayoutProblem_::Param param{                                  \
-            SrcIterator_{(_src_c_type*)src.raw_ptr, ic_stride, c, w,         \
+            SrcIterator_{(_src_c_type*)src.raw_ptr, ic_stride, ic, w,        \
                          w_pad},                                             \
-            DstIterator_{(InnerDtype_*)dst.raw_ptr, oc_stride, c, w,         \
+            DstIterator_{(InnerDtype_*)dst.raw_ptr, oc_stride, oc, w,        \
                          w_pad},                                             \
             CudaPostProcess_{src_scale, src_zero_point, dst_scale,           \
                              dst_zero_point},                                \
             n_stride_src,                                                    \
             n_stride_dst,                                                    \
             n,                                                               \
-            c,                                                               \
-            hw};                                                             \
+            ic,                                                              \
+            hw,                                                              \
+            src_zero_point};                                                 \
     auto kernel = relayout_kern<RelayoutProblem_>;                           \
     int nr_threads = query_blocksize_for_kernel(kernel);                     \
     nr_threads = std::min(nr_threads, DIVUP(problem_size, _pack_w));         \
@@ -1645,7 +1660,6 @@ void relayout_format::relayout_format_cuda_nchwx_nchw(
     megdnn_assert(false,
                   "Unsupported data type(src:%s, dst:%s) or image size(%dx%d).",
                   stype.name(), dtype.name(), h, w);
-    after_kernel_launch();
 }
 void relayout_format::relayout_format_cuda_nchw4_nchw(
@@ -21,6 +21,7 @@
 #include "cuda.h"
 #include "src/cuda/cudnn_with_check.h"
 #include "cutlass/cutlass.h"
+#include "cutlass/platform/platform.h"
 #define cuda_check(_x)                                                       \
     do {                                                                     \
@@ -448,13 +449,12 @@ MEGDNN_DEVICE __forceinline__ static int transform_int8_to_uint4x8(
 template <bool signedness, typename T>
 MEGDNN_DEVICE __forceinline__ static int unpack_integer_4bits(T storage,
                                                               int bits) {
-    uint8_t result = (uint8_t)((storage >> bits) & 0xf);
-    if (signedness) {
-        static constexpr uint8_t mask = (uint8_t)((1 << 4) - 1);
-        return (result & uint8_t(1 << 3)) ? ((int)(result) | ~(int)(mask))
-                                          : (int)(result);
-    }
-    return int(result);
+    static constexpr int shift = 28;
+    using type = typename cutlass::platform::conditional<signedness, int,
+                                                         unsigned>::type;
+    unsigned intermediate = static_cast<unsigned>(storage);
+    type result = reinterpret_cast<type&>(intermediate);
+    return (result << (shift - bits)) >> shift;
 }
 MEGDNN_DEVICE __forceinline__ static void transform_int4x8_to_int8(
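The rewritten unpack_integer_4bits trades the mask-and-branch for two shifts: the target nibble is moved up to bits 28..31, then shifted back down by 28, which sign-extends when the intermediate type is int (arithmetic shift) and zero-extends when it is unsigned. A host-side check of the signed path, relying on two's-complement shift behavior just as the device code does:

    #include <cassert>
    #include <cstdint>

    // Host model of unpack_integer_4bits<true> (signed 4-bit extraction).
    inline int unpack_s4(uint32_t storage, int bits) {
        const int shift = 28;
        int v = static_cast<int>(storage);
        return (v << (shift - bits)) >> shift;  // arithmetic shift sign-extends
    }

    int main() {
        assert(unpack_s4(0x0000000Fu, 0) == -1);  // nibble 0xF is -1 in int4
        assert(unpack_s4(0x00000070u, 4) == 7);   // nibble 0x7 stays +7
        return 0;
    }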
@@ -42,6 +42,36 @@ void recursive_cp(const TensorND& dst, const TensorND& src, size_t idx = 0,
     }
 }
+template <size_t size_nbits>
+void lowbit_recursive_cp(const TensorND& dst, const TensorND& src,
+                         size_t idx = 0, size_t src_offset = 0,
+                         size_t dst_offset = 0) {
+    MEGDNN_STATIC_ASSERT(!(8_z % size_nbits),
+                         "size in bits of lowbit data type can only be 1, 2, 4 "
+                         "or 8");
+    if (idx < (src.layout.ndim - 1)) {
+        for (size_t i = 0; i < src.layout[idx]; ++i) {
+            lowbit_recursive_cp<size_nbits>(
+                    dst, src, idx + 1, src_offset + i * src.layout.stride[idx],
+                    dst_offset + i * dst.layout.stride[idx]);
+        }
+    } else {
+        megdnn_assert(src.layout.stride[idx] == 1);
+        megdnn_assert(dst.layout.stride[idx] == 1);
+        size_t dim_bytes = div_ceil(src.layout[idx], 8_z / size_nbits);
+        // offset in elements
+        uint8_t* dptr = reinterpret_cast<uint8_t*>(dst.raw_ptr) +
+                        (dst_offset * size_nbits / 8);
+        uint8_t* sptr = reinterpret_cast<uint8_t*>(src.raw_ptr) +
+                        (src_offset * size_nbits / 8);
+        for (size_t i = 0; i < dim_bytes; ++i) {
+            *dptr = *sptr;
+            dptr++;
+            sptr++;
+        }
+    }
+}
 void padding_to_workspace(_megdnn_tensor_out dst, _megdnn_tensor_in src) {
     switch (src.layout.dtype.enumv()) {
 #define cb(name, ctype)                                                      \
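lowbit_recursive_cp recurses over the outer dimensions and copies the innermost contiguous run byte by byte; div_ceil(elements, 8 / size_nbits) converts an element count into bytes for sub-byte types. A worked example for 4-bit data (two elements per byte):

    #include <cassert>
    #include <cstddef>

    static size_t div_ceil(size_t a, size_t b) { return (a + b - 1) / b; }

    int main() {
        const size_t size_nbits = 4;
        // 15 four-bit elements occupy 8 bytes; the last byte is half used.
        assert(div_ceil(15, 8 / size_nbits) == 8);
        return 0;
    }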
@@ -54,10 +84,17 @@ void padding_to_workspace(_megdnn_tensor_out dst, _megdnn_tensor_in src) {
         cb(Int32, dt_int32);
         cb(QuantizedS32, dt_int32);
         cb(QuantizedS8, dt_qint8);
+#undef cb
+#define cb(name, size_nbits)                                                 \
+    case (DTypeEnum::name): {                                                \
+        lowbit_recursive_cp<size_nbits>(dst, src);                           \
+        break;                                                               \
+    }
+        cb(QuantizedS4, 4);
+        cb(Quantized4Asymm, 4);
+#undef cb
         default:
             megdnn_assert(0, "not support dtype %s", src.layout.dtype.name());
-#undef cb
     }
 }
@@ -66,24 +103,27 @@ void extract_from_workspace(_megdnn_tensor_out dst, _megdnn_tensor_in src,
     megdnn_assert(dst.layout.is_contiguous() && src.layout.is_contiguous(),
                   "dst %s, src %s", dst.layout.to_string().c_str(),
                   src.layout.to_string().c_str());
-    const size_t type_size = dst.layout.dtype.size();
     const size_t n = dst.layout[0];
-    const size_t n_stride_dst = dst.layout.stride[0];
-    const size_t n_stride_src = src.layout.stride[0];
+    const size_t n_stride_dst_in_bytes =
+            dst.layout.dtype.size(dst.layout.stride[0]);
+    const size_t n_stride_src_in_bytes =
+            src.layout.dtype.size(src.layout.stride[0]);
     const size_t ocpg = dst.layout[1] / group;
     const size_t icpg = src.layout[1] / group;
-    const size_t dst_hw = dst.layout[2] * dst.layout[3];
-    const size_t src_hw = src.layout[2] * src.layout[3];
-    megdnn_assert(dst_hw == src_hw);
+    const size_t dst_c_stride_in_bytes =
+            dst.layout.dtype.size(dst.layout.stride[1]);
+    const size_t src_c_stride_in_bytes =
+            src.layout.dtype.size(src.layout.stride[1]);
+    megdnn_assert(dst_c_stride_in_bytes == src_c_stride_in_bytes);
     for (size_t nid = 0; nid < n; ++nid) {
-        const size_t n_offset_dst = nid * n_stride_dst * type_size;
-        const size_t n_offset_src = nid * n_stride_src * type_size;
+        const size_t n_offset_dst = nid * n_stride_dst_in_bytes;
+        const size_t n_offset_src = nid * n_stride_src_in_bytes;
         for (size_t gid = 0; gid < group; ++gid) {
             memcpy((char*)dst.raw_ptr + n_offset_dst +
-                           gid * ocpg * dst_hw * type_size,
+                           gid * ocpg * dst_c_stride_in_bytes,
                    (char*)src.raw_ptr + n_offset_src +
-                           gid * icpg * src_hw * type_size,
-                   ocpg * dst_hw * type_size);
+                           gid * icpg * src_c_stride_in_bytes,
+                   ocpg * dst_c_stride_in_bytes);
         }
     }
 };
@@ -415,6 +455,30 @@ size_t RelayoutFormatImpl::get_workspace_in_bytes(const TensorLayout& src,
             return oc * ic * h * w * src.dtype.size();
         }
+        case Param::Mode::NCHW_NCHW64: {
+            if (src[1] % 64 != 0) {
+                size_t n = src[0];
+                size_t c = round_up(src[1], 64_z);
+                size_t h = src[2];
+                size_t w = src[3];
+                TensorLayout wsly({n, c, h, w}, src.dtype);
+                return wsly.span().dist_byte();
+            }
+            return 0_z;
+        }
+        case Param::Mode::NCHW64_NCHW: {
+            if (param().oc != 0) {
+                size_t n = src[0];
+                size_t c = src[1] * 64;
+                size_t h = src[2];
+                size_t w = src[3];
+                TensorLayout wsly({n, c, h, w}, dst.dtype);
+                return wsly.span().dist_byte();
+            }
+            return 0_z;
+        }
         default:
             return 0;
     }
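A workspace is needed only when channels are actually padded (NCHW_NCHW64 with C % 64 != 0) or restored (NCHW64_NCHW with param().oc set). Worked size for a QuantizedS4 input {1, 15, 7, 7}, assuming a contiguous 4-bit layout spans elems / 2 bytes:

    #include <cassert>
    #include <cstddef>

    static size_t round_up(size_t a, size_t b) { return (a + b - 1) / b * b; }

    int main() {
        // Channels pad 15 -> 64, so the workspace holds 1 * 64 * 7 * 7 = 3136
        // four-bit elements, i.e. 1568 bytes.
        size_t n = 1, c = round_up(15, 64), h = 7, w = 7, nbits = 4;
        assert(n * c * h * w * nbits / 8 == 1568);
        return 0;
    }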
@@ -437,6 +501,7 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
     // clean dst
     MEGDNN_DISPATCH_CPU_KERN(
             m_handle, memset(dst.raw_ptr, 0, dst.layout.span().dist_byte()));
+    // pre
     if (param().mode == Param::Mode::NCHW_NHWCD4I) {
         size_t N = src.layout[0];
         size_t IC = src.layout[1];
@@ -551,6 +616,27 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
             cb2(2, 4, NCHW_NCHW4, group_src_layout, workspace_layout);
         }
+    } else if (param().mode == Param::Mode::NCHW_NCHW64) {
+        MIDOUT_BEGIN(megdnn_naive_relayout_format,
+                     midout_iv(Param::Mode::NCHW_NCHW64)) {
+            size_t c = src.layout[1];
+            if (c % 64 != 0) {
+                uint8_t zp = 0;
+                if (src.layout.dtype.enumv() == DTypeEnum::Quantized4Asymm) {
+                    zp = src.layout.dtype.param<dtype::Quantized4Asymm>()
+                                 .zero_point;
+                    zp = (zp & 0xf) | (zp << 4);
+                }
+                MEGDNN_DISPATCH_CPU_KERN(
+                        m_handle, memset(workspace.raw_ptr, zp,
+                                         exec_workspace.span().dist_byte()));
+                TensorND ws_nd(workspace.raw_ptr, exec_workspace);
+                MEGDNN_DISPATCH_CPU_KERN(m_handle,
+                                         padding_to_workspace(ws_nd, src););
+                exec_src_nd.raw_ptr = workspace.raw_ptr;
+            }
+        }
+        MIDOUT_END();
     } else if (param().mode ==
                Param::Mode::NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT) {
         cb(1, 4, NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT);
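Since memset writes whole bytes while Quantized4Asymm stores two elements per byte, the 4-bit zero point is replicated into both nibbles before the workspace is filled:

    #include <cassert>
    #include <cstdint>

    int main() {
        uint8_t zp = 0x4;
        zp = (zp & 0xf) | (zp << 4);  // both nibbles now hold the zero point
        assert(zp == 0x44);
        return 0;
    }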
@@ -574,24 +660,16 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
             cb(1, 2, 4, NCHW_NCHW4_WEIGHT);
         }
     } else if (param().mode == Param::Mode::NCHW4_NCHW) {
-        if (exec_workspace.total_nr_elems() == dst.layout.total_nr_elems()) {
-            m_handle->relayout_opr()->exec(
-                    exec_src_nd, {dst.raw_ptr, exec_workspace}, handle());
-            return;
-        } else {
-            m_handle->relayout_opr()->exec(
-                    exec_src_nd, {workspace.raw_ptr, exec_workspace}, handle());
-            TensorLayout workspace_layout{{src.layout[0], src.layout[1] * 4,
-                                           src.layout[2], src.layout[3]},
-                                          src.layout.dtype,
-                                          src.layout.format};
-            extract_from_workspace(exec_dst_nd,
-                                   {workspace.raw_ptr, workspace_layout},
-                                   param().group);
-            return;
+        if (exec_workspace.total_nr_elems() != dst.layout.total_nr_elems()) {
+            exec_dst_nd = {workspace.raw_ptr, exec_workspace};
+        }
+    } else if (param().mode == Param::Mode::NCHW64_NCHW) {
+        if (exec_workspace.total_nr_elems() != dst.layout.total_nr_elems()) {
+            exec_dst_nd = {workspace.raw_ptr, exec_workspace};
         }
     }
+    // do relayout
     if (src.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm &&
         dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8) {
         TensorND src0 = exec_src_nd, dst0 = exec_dst_nd;
@@ -600,7 +678,6 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
             do_copy_diff_qu8_q8(dst, src);
         };
         MEGDNN_DISPATCH_CPU_KERN_OPR(func(dst0, src0));
-        return;
     } else if (src.layout.dtype.enumv() == DTypeEnum::Uint8 &&
                dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8) {
         TensorND src0 = exec_src_nd, dst0 = exec_dst_nd;
@@ -609,7 +686,6 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
             do_copy_diff_u8_q8(dst, src);
         };
         MEGDNN_DISPATCH_CPU_KERN_OPR(func(dst0, src0));
-        return;
     } else if (src.layout.dtype.enumv() == DTypeEnum::QuantizedS8 &&
                dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8) {
         TensorND src0 = exec_src_nd, dst0 = exec_dst_nd;
@@ -618,7 +694,6 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
             do_copy_diff_q8_q8(dst, src);
         };
         MEGDNN_DISPATCH_CPU_KERN_OPR(func(dst0, src0));
-        return;
     } else if (src.layout.dtype.enumv() == DTypeEnum::QuantizedS32 &&
                dst.layout.dtype.enumv() == DTypeEnum::QuantizedS32) {
         TensorND src0 = exec_src_nd, dst0 = exec_dst_nd;
@@ -627,7 +702,6 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
             do_copy_diff_q32_q32(dst, src);
         };
         MEGDNN_DISPATCH_CPU_KERN_OPR(func(dst0, src0));
-        return;
     } else if (src.layout.dtype.enumv() == DTypeEnum::QuantizedS4 &&
                dst.layout.dtype.enumv() == DTypeEnum::QuantizedS4) {
         TensorND src0 = exec_src_nd, dst0 = exec_dst_nd;
@@ -636,7 +710,6 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
             do_copy_diff_q4_q4(dst, src);
         };
         MEGDNN_DISPATCH_CPU_KERN_OPR(func(dst0, src0));
-        return;
     } else if (src.layout.dtype.enumv() == DTypeEnum::Quantized4Asymm &&
                dst.layout.dtype.enumv() == DTypeEnum::Quantized4Asymm) {
         TensorND src0 = exec_src_nd, dst0 = exec_dst_nd;
@@ -645,9 +718,20 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
             do_copy_diff_qu4_qu4(dst, src);
         };
         MEGDNN_DISPATCH_CPU_KERN_OPR(func(dst0, src0));
-        return;
     } else {
         m_handle->relayout_opr()->exec(exec_src_nd, exec_dst_nd, handle());
+    }
+    // post
+    if (param().mode == Param::Mode::NCHW4_NCHW ||
+        param().mode == Param::Mode::NCHW64_NCHW) {
+        if (exec_workspace.total_nr_elems() != dst.layout.total_nr_elems()) {
+            megdnn_assert(exec_workspace.dtype == dst.layout.dtype);
+            TensorND ws_nd{workspace.raw_ptr, exec_workspace};
+            MEGDNN_DISPATCH_CPU_KERN(
+                    m_handle,
+                    extract_from_workspace(dst, ws_nd, param().group););
+        }
     }
 #undef cb
 }
@@ -18,7 +18,6 @@
 using namespace megdnn;
 using namespace test;
-#define MEGDNN_WITH_BENCHMARK 1
 TEST_F(CUDA, RELAYOUT_FORMAT) {
     Checker<RelayoutFormat> checker(handle_cuda());
@@ -245,7 +244,7 @@ TEST_F(CUDA, RELAYOUT_FORMAT_NCHW_NCHW64) {
     param::RelayoutFormat param;
     param.mode = param::RelayoutFormat::Mode::NCHW_NCHW64;
     for (size_t n : {1, 3}) {
-        for (size_t c : {64, 128}) {
+        for (size_t c : {15, 64, 128}) {
             for (size_t h : {7, 14, 16, 28}) {
                 for (size_t w : {2, 3, 7, 8, 16, 31}) {
                     checker.set_dtype(0, dtype::QuantizedS4{2.f})
@@ -285,36 +284,41 @@ TEST_F(CUDA, RELAYOUT_FORMAT_NCHW64_NCHW) {
     param::RelayoutFormat param;
     param.mode = param::RelayoutFormat::Mode::NCHW64_NCHW;
     for (size_t n : {1, 3}) {
-        for (size_t c : {64, 128}) {
+        for (size_t c : {15, 64, 128}) {
             for (size_t h : {7, 14, 16, 28}) {
                 for (size_t w : {2, 3, 4, 7, 14, 16, 17}) {
+                    if (c % 64 != 0) {
+                        param.oc = c;
+                    } else {
+                        param.oc = 0;
+                    }
                     checker.set_dtype(0, dtype::QuantizedS4{2.f})
                             .set_dtype(1, dtype::QuantizedS4{2.f})
                             .set_rng(0, &s4)
                             .set_param(param)
                             .set_epsilon(1e-3)
-                            .execs({{n, c / 64, h, w, 64}, {}});
+                            .execs({{n, (c + 63) / 64, h, w, 64}, {}});
                     checker.set_dtype(0, dtype::Quantized4Asymm{1.2f, 4})
                             .set_dtype(1, dtype::Quantized4Asymm{1.2f, 8})
                             .set_rng(0, &u4)
                             .set_param(param)
                             .set_epsilon(1e-3)
-                            .execs({{n, c / 64, h, w, 64}, {}});
+                            .execs({{n, (c + 63) / 64, h, w, 64}, {}});
                     checker.set_dtype(0, dtype::QuantizedS4{1.19990307f})
                             .set_dtype(1, dtype::QuantizedS4{1.f})
                             .set_rng(0, &s4)
                             .set_param(param)
                             .set_epsilon(1e-3)
-                            .execs({{n, c / 64, h, w, 64}, {}});
+                            .execs({{n, (c + 63) / 64, h, w, 64}, {}});
                     checker.set_dtype(0, dtype::Quantized4Asymm{1.20211209f, 8})
                             .set_dtype(1, dtype::Quantized4Asymm{1.f, 4})
                             .set_rng(0, &u4)
                             .set_param(param)
                             .set_epsilon(1e-3)
-                            .execs({{n, c / 64, h, w, 64}, {}});
+                            .execs({{n, (c + 63) / 64, h, w, 64}, {}});
                 }
             }
         }
@@ -375,10 +379,14 @@ TEST_F(CUDA, BENCHMARK_RELAYOUT_FORMAT_QS4) {
     CUBenchmarker<RelayoutFormat> benchmarker(handle_cuda());
     benchmarker.set_param(param);
     benchmarker.set_dtype(0, dtype::QuantizedS4{1.19990307f})
-            .set_dtype(1, dtype::QuantizedS4{1.20210322f});
+            .set_dtype(1, dtype::QuantizedS4{1.19990307f});
     for (auto&& shape : shapes) {
-        double memaccess = double(shape.total_nr_elems()) * 1e-6;
+        double memaccess =
+                double(TensorLayout(shape, dtype::QuantizedS4{1.f})
+                               .span()
+                               .dist_byte()) *
+                2e-6;
         auto time_ms = benchmarker.execs({shape, {}});
         printf("execute %s, time %.4f ms, %.4f GB/s\n",
                shape.to_string().c_str(), time_ms, memaccess / time_ms);
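For scale: {64, 64, 56, 56} in QuantizedS4 spans 6,422,528 bytes (12,845,056 four-bit elements), so memaccess = dist_byte() * 2e-6 ≈ 12.85, counting one read of the source and one write of the destination; dividing by the measured milliseconds then reports GB/s.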
@@ -387,8 +395,9 @@ TEST_F(CUDA, BENCHMARK_RELAYOUT_FORMAT_QS4) {
     {
         TensorShapeArray shapes = {
-                {1, 64, 56, 56},  {16, 64, 56, 56}, {64, 64, 56, 56},
-                {1, 64, 56, 55},  {16, 64, 56, 55}, {64, 64, 56, 55},
+                {1, 64, 56, 56},  {16, 64, 56, 56}, {64, 64, 56, 56},
+                {1, 64, 56, 55},  {16, 64, 56, 55}, {64, 64, 56, 55},
+                {1, 256, 384, 640},
         };
         Param param;
         param.mode = param::RelayoutFormat::Mode::NCHW_NCHW64;
@@ -399,7 +408,8 @@ TEST_F(CUDA, BENCHMARK_RELAYOUT_FORMAT_QS4) {
                 {64, 1, 56, 56, 64},
                 {1, 32, 7, 7, 64},
                 {16, 32, 7, 7, 64},
-                {64, 32, 7, 7, 64},
+                {64, 32, 7, 7, 64},
+                {1, 4, 384, 640, 64},
         };
         Param param;
         param.mode = param::RelayoutFormat::Mode::NCHW64_NCHW;