GitOrigin-RevId: 1445ecfabe
release-1.5
@@ -196,6 +196,32 @@ public:
            const TensorLayout& layout) const override;
};
using Image2DPack4TensorFormatBase = Image2DPackedTensorFormatBase<4>;
///*!
// * \brief used for tensors with lowbit data type
// *
// * \p SIZE_NBITS is the size in bits of an element of the tensor.
// *
// */
//template <size_t SIZE_NBITS_>
//class LowbitTensorFormat : public TensorFormat::ImplBase {
//    static constexpr size_t SIZE_NBITS = SIZE_NBITS_;
//    size_t m_align_size_in_bits;
//
//protected: //?
//    LowbitTensorFormat(Type type, size_t align_size_in_bits);
//
//public:
//    size_t align_size_in_bits() const {
//        return m_align_size_in_bits;
//    }
//
//    std::string to_string() const override;
//
//    void serialize_append(std::string& result) const override;
//};
} // namespace detail
/*!
@@ -895,6 +895,7 @@ Relayout mode.
* ``NCHW4`` layout: ``{N, C/4, H, W, 4}``
* ``NCHW88`` layout: ``{N, C/8, H, W, 8}``
* ``CHWN4`` layout: ``{C/4, H, W, N, 4}``
* ``NCHW64`` layout: ``{N, C/64, H, W, 64}``
**Float weight transformation definitions**
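For reference, NCHW64 stores channels in blocks of 64, so element (n, c, h, w) of an NCHW tensor lands at position (n, c/64, h, w, c%64). A minimal host-side sketch of the mapping (hypothetical helper, not part of this patch; assumes contiguous float buffers and C divisible by 64):

// Hypothetical host-side reference of the NCHW -> NCHW64 mapping.
#include <cstddef>

void nchw_to_nchw64_ref(const float* src, float* dst, size_t N, size_t C,
                        size_t H, size_t W) {
    const size_t HW = H * W;
    for (size_t n = 0; n < N; ++n)
        for (size_t c = 0; c < C; ++c)
            for (size_t hw = 0; hw < HW; ++hw)
                // dst index (n, c / 64, hw, c % 64) in {N, C/64, H, W, 64}
                dst[((n * (C / 64) + c / 64) * HW + hw) * 64 + c % 64] =
                        src[(n * C + c) * HW + hw];
}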
@@ -969,6 +970,8 @@ Note: NCHW_NCHW4_WEIGHT will auto pad oc and ic, you should remove oc in later o
        'NCHW_NCHW4',
        'NCHW4_NCHW',
        'NCHW_NCHW4_WEIGHT',
        'NCHW_NCHW64',
        'NCHW64_NCHW',
    )
)
@@ -251,6 +251,23 @@ void RelayoutFormat::deduce_layout_fwd(const TensorLayout& src,
            dst[3] = src[3];
            megdnn_assert(dst[1] % param().group == 0);
            break;
        case Param::Mode::NCHW_NCHW64:
            megdnn_assert(src.ndim == 4 && (src[1] % 64) == 0);
            dst.ndim = 5;
            dst[0] = src[0];
            dst[1] = src[1] / 64;
            dst[2] = src[2];
            dst[3] = src[3];
            dst[4] = 64;
            break;
        case Param::Mode::NCHW64_NCHW:
            megdnn_assert(src.ndim == 5);
            dst.ndim = 4;
            dst[0] = src[0];
            dst[1] = src[1] * 64;
            dst[2] = src[2];
            dst[3] = src[3];
            break;
        default:
            megdnn_assert(0, "Invalid RelayoutFormat Mode");
            break;
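A worked example of the two new branches, under the asserted constraint that the channel dimension is a multiple of 64 (standalone sketch, not patch code):

// Worked shape example for the new modes:
//   NCHW_NCHW64: {2, 128, 7, 7}   -> {2, 128 / 64, 7, 7, 64} = {2, 2, 7, 7, 64}
//   NCHW64_NCHW: {2, 2, 7, 7, 64} -> {2, 2 * 64, 7, 7}       = {2, 128, 7, 7}
#include <cassert>
#include <cstddef>

void nchw_nchw64_shape_roundtrip() {
    size_t src[4] = {2, 128, 7, 7};
    assert(src[1] % 64 == 0);  // mirrors the megdnn_assert above
    size_t dst[5] = {src[0], src[1] / 64, src[2], src[3], 64};
    assert(dst[1] * 64 == src[1]);  // NCHW64_NCHW inverts the transform
}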
@@ -352,7 +369,12 @@ void RelayoutFormat::deduce_format(TensorFormat src, TensorFormat& dst) {
            CHECK_SRC(DefaultTensorFormat::make());
            dst = src;
            break;
        case Param::Mode::NCHW_NCHW64:
            dst = src;
            break;
        case Param::Mode::NCHW64_NCHW:
            dst = src;
            break;
        default:
            megdnn_throw("Invalid relayout format mode");
            break;
@@ -633,6 +655,19 @@ void RelayoutFormat::deduce_exec_layout(const TensorLayout& src,
            exec_src = src.dimshuffle({3, 0, 1, 2, 4});
            exec_dst = dst;
            break;
        case Param::Mode::NCHW_NCHW64:
            // src is {N, C, H, W}
            // dst is {N, C/64, H, W, 64}
            exec_src = src.reshape({src[0], src[1] / 64, 64, src[2], src[3]})
                               .dimshuffle({0, 1, 3, 4, 2});
            exec_dst = dst;
            break;
        case Param::Mode::NCHW64_NCHW:
            // src is {N, C/64, H, W, 64}
            // dst is {N, C, H, W}
            exec_src = src.dimshuffle({0, 1, 4, 2, 3});
            exec_dst = dst;
            break;
        default:
            megdnn_assert(0, "Invalid RelayoutFormat Mode");
    }
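The NCHW_NCHW64 exec layout is a zero-copy view: reshape splits C into (C/64, 64) and dimshuffle({0, 1, 3, 4, 2}) moves the 64-wide block innermost by permuting shapes and strides; the subsequent relayout does the actual data movement. A sketch of the stride arithmetic for a contiguous NCHW source (hypothetical standalone code):

#include <cstddef>

// What reshape + dimshuffle produce for a contiguous NCHW source
// (sketch; megdnn's TensorLayout::dimshuffle permutes strides the same way).
struct View5 {
    size_t shape[5], stride[5];
};

View5 nchw_as_nchw64_view(size_t N, size_t C, size_t H, size_t W) {
    size_t shp[5] = {N, C / 64, 64, H, W};                 // after reshape
    size_t str[5] = {C * H * W, 64 * H * W, H * W, W, 1};  // contiguous strides
    const int perm[5] = {0, 1, 3, 4, 2};                   // dimshuffle order
    View5 v;
    for (int i = 0; i < 5; ++i) {
        v.shape[i] = shp[perm[i]];
        v.stride[i] = str[perm[i]];
    }
    return v;  // shape {N, C/64, H, W, 64}, stride {CHW, 64HW, W, 1, HW}
}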
@@ -69,12 +69,9 @@ size_t ConvBiasForwardImpl::AlgoFallbackNCHWQS4::get_workspace_in_bytes(
void ConvBiasForwardImpl::AlgoFallbackNCHWQS4::exec(
        const ExecArgs& args) const {
    using Format = Param::Format;
    auto&& param = args.opr->param();
    auto&& fm = args.filter_meta;
    auto layouts = make_underlying_tensor_layout(
            *(args.src_layout), fm, *(args.bias_layout), *(args.z_layout),
            *(args.dst_layout));
            *(args.src_layout), *(args.filter_layout), *(args.bias_layout),
            *(args.z_layout), *(args.dst_layout));
    auto ws = get_workspace_bundle(args.workspace.raw_ptr, args);
    auto ws_src = ws.get(0);
    auto ws_filter = ws.get(1);
@@ -82,20 +79,27 @@ void ConvBiasForwardImpl::AlgoFallbackNCHWQS4::exec(
    void* ws_z = nullptr;
    if (args.z_layout->ndim > 0)
        ws_z = ws.get(4);
    auto&& stream = cuda_stream(args.opr->handle());
    auto nchw2nchw64 = [](const TensorND& src, void* raw_dptr) {
        if (raw_dptr == nullptr)
    // auto&& stream = cuda_stream(args.opr->handle());
    auto nchw2nchw64 = [&args](const TensorND& src, TensorND&& dst) {
        if (dst.raw_ptr == nullptr)
            return;
        auto relayout = args.handle->create_operator<RelayoutFormat>();
        relayout->param() = RelayoutFormat::Param::Mode::NCHW_NCHW64;
        Workspace dummy;
        relayout->exec(src, dst, dummy);
    };
    auto nchw642nchw = [](const TensorND& src, void* raw_dptr) {
    auto nchw642nchw = [&args](const TensorND& src, TensorND&& dst) {
        auto relayout = args.handle->create_operator<RelayoutFormat>();
        relayout->param() = RelayoutFormat::Param::Mode::NCHW64_NCHW;
        Workspace dummy;
        relayout->exec(src, dst, dummy);
    };
    // reformat src
    nchw2nchw64(*(args.src_tensor), ws_src);
    nchw2nchw64(*(args.src_tensor), {ws_src, layouts[0]});
    // reformat filter
    nchw2nchw64(*(args.filter_tensor), ws_filter);
    nchw2nchw64(*(args.filter_tensor), {ws_filter, layouts[1]});
    // reformat z
    nchw2nchw64(*(args.z_tensor), ws_z);
    nchw2nchw64(*(args.z_tensor), {ws_z, layouts[3]});
    TensorND src_{ws_src, layouts[0]}, filter_{ws_filter, layouts[1]},
            bias_{args.bias_tensor->raw_ptr, layouts[2]}, z_{ws_z, layouts[3]},
            dst_{ws_dst, layouts[4]};
@@ -109,22 +113,22 @@ void ConvBiasForwardImpl::AlgoFallbackNCHWQS4::exec(
            args.preprocessed_filter};
    m_underlying_algo.exec(args);
    // reformat dst
    nchw642nchw(dst_, args.dst_tensor->raw_ptr);
    nchw642nchw(dst_, {args.dst_tensor->raw_ptr, args.dst_tensor->layout});
}
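The two lambdas follow the usual megdnn pattern for invoking a nested operator; isolated, the pattern looks roughly like this (sketch; handle and tensor names are placeholders):

#include "megdnn/oprs.h"  // assumed include; adjust to the local convention

// Sketch of the lambda pattern above: run RelayoutFormat as a nested operator.
void relayout_nchw_to_nchw64(megdnn::Handle* handle,
                             const megdnn::TensorND& src,
                             const megdnn::TensorND& dst) {
    auto relayout = handle->create_operator<megdnn::RelayoutFormat>();
    relayout->param() = megdnn::RelayoutFormat::Param::Mode::NCHW_NCHW64;
    megdnn::Workspace dummy;  // this mode needs no scratch memory here
    relayout->exec(src, dst, dummy);
}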
SmallVector<TensorLayout>
ConvBiasForwardImpl::AlgoFallbackNCHWQS4::make_underlying_tensor_layout(
        const TensorLayout& src, const CanonizedFilterMeta& filter_meta,
        const TensorLayout& src, const TensorLayout& filter,
        const TensorLayout& bias, const TensorLayout& z,
        const TensorLayout& dst) const {
    size_t n = src[0], ci = src[1], hi = src[2], wi = src[3];
    size_t co = dst[1], ho = dst[2], wo = dst[3];
    size_t fh = filter_meta.spatial[0], fw = filter_meta.spatial[1];
    size_t fh = filter[2], fw = filter[3];
    SmallVector<TensorLayout> rst;
    rst.emplace_back(TensorLayout{{n, ci / 64, hi, wi, 64}, src.dtype});
    rst.emplace_back(TensorLayout{{co, ci / 64, fh, fw, 64}, filter.dtype});
    rst.emplace_back(TensorLayout{{1, co / 64, 1, 1, 64}, bias.dtype});
    if (z.layout.ndim > 0) {
    if (z.ndim > 0) {
        rst.emplace_back(TensorLayout{{n, co / 64, ho, wo, 64}, z.dtype});
    } else {
        rst.emplace_back(TensorLayout{});
@@ -134,15 +138,13 @@ ConvBiasForwardImpl::AlgoFallbackNCHWQS4::make_underlying_tensor_layout(
}
WorkspaceBundle ConvBiasForwardImpl::AlgoFallbackNCHWQS4::get_workspace_bundle(
        void* ptr, const SizeArgs& args) const {
        void* raw_ptr, const SizeArgs& args) const {
    size_t ws_size_src = args.src_layout->span().dist_byte();
    size_t ws_size_filter = args.filter_layout->span().dist_byte();
    size_t ws_size_dst = args.dst_layout->span().dist_byte();
    auto&& param = args.opr->param();
    auto&& fm = args.filter_meta;
    auto layouts = make_underlying_tensor_layout(
            *(args.src_layout), fm, *(args.bias_layout), *(args.z_layout),
            *(args.dst_layout));
            *(args.src_layout), *(args.filter_layout), *(args.bias_layout),
            *(args.z_layout), *(args.dst_layout));
    SizeArgs args_{args.opr,
                   layouts[0],
                   layouts[1],
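Each workspace slot is sized with span().dist_byte(), which for packed low-bit dtypes rounds the bit count up to whole bytes, so QuantizedS4 tensors cost half a byte per element. A worked sketch, assuming that rounding behavior:

#include <cstddef>

// Byte extent of a contiguous 4-bit tensor, assuming dist_byte() packs two
// elements per byte and rounds up (sketch, not megdnn's implementation).
size_t qint4_bytes(size_t nr_elems) {
    return (nr_elems * 4 + 7) / 8;
}
// e.g. a {2, 2, 7, 7, 64} QuantizedS4 tensor: 12544 elements -> 6272 bytes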
@@ -78,29 +78,33 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
        return handle()->create_operator<RelayoutForward>()->exec(
                {src.raw_ptr, exec_src_layout}, {dst.raw_ptr, exec_dst_layout});
    }
    if (param().mode == Param::Mode::NCHW_NCHW4 ||
        param().mode == Param::Mode::NCHW4_NCHW ||
        param().mode == Param::Mode::NCHW_NCHW4_WEIGHT) {
    bool is_trans_4bits = (param().mode == Param::Mode::NCHW_NCHW64 ||
                           param().mode == Param::Mode::NCHW64_NCHW) &&
                          (src_dtype.enumv() == DTypeEnum::QuantizedS4 ||
                           src_dtype.enumv() == DTypeEnum::Quantized4Asymm);
    bool is_nchw_nchw4 = param().mode == Param::Mode::NCHW_NCHW4 ||
                         param().mode == Param::Mode::NCHW4_NCHW ||
                         param().mode == Param::Mode::NCHW_NCHW4_WEIGHT;
    if (is_trans_4bits || is_nchw_nchw4) {
        bool is_usable = relayout_format::RelayoutFormatFast::usable(
                src.layout, dst.layout);
        megdnn_assert(is_usable,
                      "RelayoutFormatNCHW_NCHW4 kernel not usable for %s(%s) "
                      "to %s(%s)",
                      "RelayoutFormatFast kernel is not usable for "
                      "transforming %s(%s) to %s(%s).",
                      src.layout.to_string().c_str(), src.layout.dtype.name(),
                      dst.layout.to_string().c_str(), dst.layout.dtype.name());
        relayout_format::RelayoutFormatFast::exec(src, dst,
                                                  cuda_stream(this->handle()),
                                                  param().mode, param().group);
    } else {
        TensorLayout exec_src, exec_dst, exec_workspace;
        deduce_exec_layout(src.layout, dst.layout, exec_workspace, exec_src,
                           exec_dst);
        TensorND exec_src_nd{src.raw_ptr, exec_src};
        TensorND exec_dst_nd{dst.raw_ptr, exec_dst};
        handle()->create_operator<RelayoutForward>()->exec(exec_src_nd,
                                                           exec_dst_nd);
        return relayout_format::RelayoutFormatFast::exec(
                src, dst, cuda_stream(this->handle()), param().mode,
                param().group);
    }
    // fallback impls
    TensorLayout exec_src, exec_dst, exec_workspace;
    deduce_exec_layout(src.layout, dst.layout, exec_workspace, exec_src,
                       exec_dst);
    TensorND exec_src_nd{src.raw_ptr, exec_src};
    TensorND exec_dst_nd{dst.raw_ptr, exec_dst};
    handle()->create_operator<RelayoutForward>()->exec(exec_src_nd,
                                                       exec_dst_nd);
}
size_t RelayoutFormatImpl::get_workspace_in_bytes(
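Condensed, the new dispatch takes the fast CUDA kernel for 4-bit NCHW<->NCHW64 transforms and for all NCHW4 modes, and otherwise falls back to a plain RelayoutForward on the deduced exec layouts. A restatement of the predicate (sketch, not the literal code):

#include "megdnn/oprs.h"  // assumed include

using Mode = megdnn::RelayoutFormat::Param::Mode;

// Restatement of the fast-path predicate above (sketch only).
bool takes_fast_path(Mode mode, megdnn::DTypeEnum src_dtype) {
    bool is_trans_4bits = (mode == Mode::NCHW_NCHW64 ||
                           mode == Mode::NCHW64_NCHW) &&
                          (src_dtype == megdnn::DTypeEnum::QuantizedS4 ||
                           src_dtype == megdnn::DTypeEnum::Quantized4Asymm);
    bool is_nchw_nchw4 = mode == Mode::NCHW_NCHW4 ||
                         mode == Mode::NCHW4_NCHW ||
                         mode == Mode::NCHW_NCHW4_WEIGHT;
    return is_trans_4bits || is_nchw_nchw4;
}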
@@ -24,6 +24,8 @@ inline void get_scale_zeropoint(const DType& tensor_dtype, float& scale,
        scale = tensor_dtype.param<dtype::Quantized8Asymm>().scale;
    } else if (tensor_dtype.enumv() == DTypeEnum::QuantizedS8) {
        scale = tensor_dtype.param<dtype::QuantizedS8>().scale;
    } else if (tensor_dtype.enumv() == DTypeEnum::QuantizedS4) {
        scale = tensor_dtype.param<dtype::QuantizedS4>().scale;
    }
}
@@ -39,9 +41,8 @@ void relayout_format::RelayoutFormatFast::exec(const TensorND& src,
                                               cudaStream_t stream,
                                               RelayoutFormat::Param::Mode mode,
                                               int group) {
    size_t ih = src.layout[2];
    size_t iw = src.layout[3];
    size_t hw = ih * iw;
    auto&& stype = src.layout.dtype;
    auto&& dtype = dst.layout.dtype;
    float src_scale = 1.f;
    float dst_scale = 1.f;
    uint8_t src_zero_point = 0;
@@ -51,22 +52,28 @@ void relayout_format::RelayoutFormatFast::exec(const TensorND& src,
    if (src.layout.dtype.enumv() == DTypeEnum::Uint8) {
        src_zero_point = 128;
    }
    if (mode == RelayoutFormat::Param::Mode::NCHW_NCHW4) {
        if (hw % 4 == 0) {
            relayout_format_cuda_nchw_nchw4<4>(src, dst, stream, src_scale,
    if (mode == RelayoutFormat::Param::Mode::NCHW_NCHW4 ||
        mode == RelayoutFormat::Param::Mode::NCHW_NCHW64) {
        return relayout_format_cuda_nchw_nchwx(src, dst, stream, src_scale,
                                               dst_scale, src_zero_point,
                                               dst_zero_point, group);
    } else {
        relayout_format_cuda_nchw_nchw4<1>(src, dst, stream, src_scale,
    } else if (mode == RelayoutFormat::Param::Mode::NCHW64_NCHW) {
        megdnn_assert(group == 1,
                      "RelayoutFormat kernel only supports transforming "
                      "NCHW64 to NCHW with group == 1 (got group: %d)",
                      group);
        return relayout_format_cuda_nchwx_nchw(src, dst, stream, src_scale,
                                               dst_scale, src_zero_point,
                                               dst_zero_point, group);
        }
                                               dst_zero_point);
    } else if (mode == RelayoutFormat::Param::Mode::NCHW_NCHW4_WEIGHT) {
        relayout_format_cuda_nchw_nchw4_weight(src, dst, stream);
        return relayout_format_cuda_nchw_nchw4_weight(src, dst, stream);
    } else if (mode == RelayoutFormat::Param::Mode::NCHW4_NCHW) {
        relayout_format_cuda_nchw4_nchw(src, dst, stream, group);
        return relayout_format_cuda_nchw4_nchw(src, dst, stream, group);
    } else {
        megdnn_throw("only support nchw_nchw4 nchw4_nchw layout_format");
        megdnn_throw(
                "only supports nchw_nchw64/nchw64_nchw/nchw_nchw4/nchw4_nchw "
                "layout formats");
    }
}
// vim: ft=cpp syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
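The scale/zero-point pairs gathered above drive a standard per-element requantization. A plausible scalar form of what the kernels compute per 4-bit lane (sketch, assuming unsigned 4-bit output; not the literal kernel code):

#include <algorithm>
#include <cmath>
#include <cstdint>

// Plausible scalar form of the conversion driven by the parameters above;
// the CUDA kernels vectorize this over packed 4-bit lanes.
uint8_t requantize_u4(uint8_t in, float src_scale, uint8_t src_zp,
                      float dst_scale, uint8_t dst_zp) {
    float dequant = (static_cast<int>(in) - src_zp) * src_scale;
    int q = static_cast<int>(std::round(dequant / dst_scale)) + dst_zp;
    return static_cast<uint8_t>(std::min(std::max(q, 0), 15));
}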
@@ -19,14 +19,11 @@ namespace megdnn {
namespace cuda {
namespace relayout_format {
template <int pack_w = 1>
void relayout_format_cuda_nchw_nchw4(const TensorND& src, const TensorND& dst,
                                     const cudaStream_t& stream,
                                     const float src_scale = 1.f,
                                     const float dst_scale = 1.f,
                                     const uint8_t src_zero_point = 0,
                                     const uint8_t dst_zero_point = 0,
                                     const int group = 1);
void relayout_format_cuda_nchw_nchwx(
        const TensorND& src, const TensorND& dst, const cudaStream_t& stream,
        const float src_scale = 1.f, const float dst_scale = 1.f,
        const uint8_t src_zero_point = 0, const uint8_t dst_zero_point = 0,
        const int group = 1);
bool relayout_format_cuda_usable(const TensorLayout& src_layout,
                                 const TensorLayout& dst_layout);
@@ -35,6 +32,13 @@ void relayout_format_cuda_nchw4_nchw(const TensorND& src, const TensorND& dst,
                                     const cudaStream_t& stream,
                                     const int group);
void relayout_format_cuda_nchwx_nchw(const TensorND& src, const TensorND& dst,
                                     const cudaStream_t& stream,
                                     const float src_scale = 1.f,
                                     const float dst_scale = 1.f,
                                     const uint8_t src_zero_point = 0,
                                     const uint8_t dst_zero_point = 0);
void relayout_format_cuda_nchw_nchw4_weight(const TensorND& src,
                                            const TensorND& dst,
                                            const cudaStream_t& stream);
@@ -110,6 +110,12 @@ MEGDNN_NORETURN void report_error(const char* msg);
template <typename T, size_t N>
struct array_wrapper {
    T data[N];
    MEGDNN_DEVICE __forceinline__ T& operator[](size_t pos) {
        return reinterpret_cast<T&>(data[pos]);
    }
    MEGDNN_DEVICE __forceinline__ T const& operator[](size_t pos) const {
        return reinterpret_cast<T const&>(data[pos]);
    }
};
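The new operator[] overloads let device code index array_wrapper like a raw array; a minimal hypothetical kernel using both overloads:

// Hypothetical device-side usage of the new operator[] overloads.
__global__ void iota_through_wrapper(int* out) {
    array_wrapper<int, 8> buf;
    for (int i = 0; i < 8; ++i)
        buf[i] = i;  // non-const operator[]
    const array_wrapper<int, 8>& cbuf = buf;
    for (int i = 0; i < 8; ++i)
        out[threadIdx.x * 8 + i] = cbuf[i];  // const operator[]
}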
/*!
@@ -207,12 +213,29 @@ struct CudaDTypeParamImpl<dt_quint4> : DTypeParamImpl<dt_quint4> {
    CudaDTypeParamImpl(const DTypeParamImpl<dt_quint4>& param)
            : CudaDTypeParamImpl(param.scale, param.zero_point) {}
    __device__ uint8_t quantize(float in) const {
    __device__ dt_quint4 quantize(float in) const {
        float v = in * inv_scale;
        v = roundf(v);
        v = v + zero_point;
        v = fmin(fmax(0.f, v), 15.f);
        return static_cast<uint8_t>(v);
        return static_cast<dt_quint4>(v);
    }
};
template <>
struct CudaDTypeParamImpl<dt_qint4> : DTypeParamImpl<dt_qint4> {
    float inv_scale;
    CudaDTypeParamImpl() = default;
    CudaDTypeParamImpl(float scale)
            : DTypeParamImpl<dt_qint4>(scale), inv_scale(1.0f / scale) {}
    CudaDTypeParamImpl(const DTypeParamImpl<dt_qint4>& param)
            : CudaDTypeParamImpl(param.scale) {}
    __device__ dt_qint4 quantize(float in) const {
        float v = in * inv_scale;
        v = roundf(v);
        v = fmin(fmax(-8.f, v), 7.f);
        return static_cast<dt_qint4>(v);
    }
};
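A worked example of the new dt_qint4 quantizer: with scale = 0.5 (so inv_scale = 2), 3.2f maps to round(6.4) = 6; 10.f saturates to 7 and -9.f to -8. A host-side restatement (sketch):

#include <algorithm>
#include <cmath>
#include <cstdint>

// Host-side restatement of CudaDTypeParamImpl<dt_qint4>::quantize (sketch).
int8_t quantize_s4_ref(float in, float scale) {
    float v = std::round(in / scale);
    return static_cast<int8_t>(std::min(std::max(-8.f, v), 7.f));
}
// scale = 0.5f: quantize_s4_ref(3.2f, 0.5f) == 6, quantize_s4_ref(10.f, 0.5f) == 7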
@@ -351,6 +374,110 @@ MEGDNN_DEVICE __forceinline__ static float4 operator+(float4 lval,
    return make_float4(lval.x + rval.x, lval.y + rval.y, lval.z + rval.z,
                       lval.w + rval.w);
}
MEGDNN_DEVICE __forceinline__ static int transform_int8_to_int4x8(
        int s0, int s1, int s2, int s3, int s4, int s5, int s6, int s7) {
    unsigned out;
#if __CUDA_ARCH__ >= 750
    asm volatile(
            "{ .reg .u32 r4;"
            "cvt.pack.sat.s4.s32.b32 r4, %8, %7, 0;"
            "cvt.pack.sat.s4.s32.b32 r4, %6, %5, r4;"
            "cvt.pack.sat.s4.s32.b32 r4, %4, %3, r4;"
            "cvt.pack.sat.s4.s32.b32 %0, %2, %1, r4;"
            "}"
            : "=r"(out)
            : "r"(s0), "r"(s1), "r"(s2), "r"(s3), "r"(s4), "r"(s5), "r"(s6),
              "r"(s7));
#else
#define CVT_SAT_S4_S32(r, bits) \
    r = r <= -8 ? -8 : r;       \
    r = r > 7 ? 7 : r;          \
    r = (((unsigned)r & 0xf) << bits);
    CVT_SAT_S4_S32(s0, 0)
    CVT_SAT_S4_S32(s1, 4)
    CVT_SAT_S4_S32(s2, 8)
    CVT_SAT_S4_S32(s3, 12)
    CVT_SAT_S4_S32(s4, 16)
    CVT_SAT_S4_S32(s5, 20)
    CVT_SAT_S4_S32(s6, 24)
    CVT_SAT_S4_S32(s7, 28)
    out = s0 + s1 + s2 + s3 + s4 + s5 + s6 + s7;
#undef CVT_SAT_S4_S32
#endif
    return reinterpret_cast<int const&>(out);
}
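Both paths agree on the bit layout: nibble i of the result holds sat(s_i) in two's complement, with s0 in bits [3:0] and s7 in bits [31:28] (cvt.pack.sat.s4.s32.b32 saturates each 32-bit operand to [-8, 7] before packing, matching the #else branch). A host-side reference with a worked value (sketch):

#include <cstdint>

// Host-side reference of transform_int8_to_int4x8 (sketch).
uint32_t pack_s4x8_ref(const int (&s)[8]) {
    uint32_t out = 0;
    for (int i = 0; i < 8; ++i) {
        int v = s[i] < -8 ? -8 : (s[i] > 7 ? 7 : s[i]);  // saturate to [-8, 7]
        out |= (static_cast<uint32_t>(v) & 0xfu) << (4 * i);
    }
    return out;
}
// e.g. {0, 1, -1, 7, -8, 100, -100, 3} -> 0x38787f10
//      (100 saturates to 7, -100 to -8; -1 encodes as 0xf)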
MEGDNN_DEVICE __forceinline__ static int transform_int8_to_uint4x8(
        int s0, int s1, int s2, int s3, int s4, int s5, int s6, int s7) {
    unsigned out;
#if __CUDA_ARCH__ >= 750
    asm volatile(
            "{ .reg .u32 r4;"
            "cvt.pack.sat.u4.s32.b32 r4, %8, %7, 0;"
            "cvt.pack.sat.u4.s32.b32 r4, %6, %5, r4;"
            "cvt.pack.sat.u4.s32.b32 r4, %4, %3, r4;"
            "cvt.pack.sat.u4.s32.b32 %0, %2, %1, r4;"
            "}"
            : "=r"(out)
            : "r"(s0), "r"(s1), "r"(s2), "r"(s3), "r"(s4), "r"(s5), "r"(s6),
              "r"(s7));
#else
#define CVT_SAT_U4_S32(r, bits) \
    r = r <= 0 ? 0 : r;         \
    r = r > 15 ? 15 : r;        \
    r = (((unsigned)r & 0xf) << bits);
    CVT_SAT_U4_S32(s0, 0)
    CVT_SAT_U4_S32(s1, 4)
    CVT_SAT_U4_S32(s2, 8)
    CVT_SAT_U4_S32(s3, 12)
    CVT_SAT_U4_S32(s4, 16)
    CVT_SAT_U4_S32(s5, 20)
    CVT_SAT_U4_S32(s6, 24)
    CVT_SAT_U4_S32(s7, 28)
    out = s0 + s1 + s2 + s3 + s4 + s5 + s6 + s7;
#undef CVT_SAT_U4_S32
#endif
    return reinterpret_cast<int const&>(out);
}
template <bool signedness>
MEGDNN_DEVICE __forceinline__ static int unpack_integer_4bits(unsigned storage,
                                                              unsigned bits);
template <>
MEGDNN_DEVICE __forceinline__ int unpack_integer_4bits<true>(unsigned storage,
                                                             unsigned bits) {
    uint8_t result = (uint8_t)((unsigned)(storage >> bits) & 0xf);
    static constexpr uint8_t mask = (uint8_t)((1 << 4) - 1);
    return (result & uint8_t(1 << 3)) ? ((int)(result) | ~(int)(mask))
                                      : (int)(result);
}
template <>
MEGDNN_DEVICE __forceinline__ int unpack_integer_4bits<false>(unsigned storage,
                                                              unsigned bits) {
    uint8_t result = (uint8_t)((unsigned)(storage >> bits) & 0xf);
    return (int)(result);
}
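The signed specialization sign-extends from bit 3, so 0xf decodes to -1 and 0x8 to -8, while the unsigned one returns the raw nibble. A compact restatement (sketch):

// Compact restatement of the signed unpack: extract the nibble, then
// sign-extend from bit 3 (0xf -> -1, 0x8 -> -8, 0x7 -> 7).
int unpack_s4_ref(unsigned storage, unsigned bits) {
    int v = (storage >> bits) & 0xf;
    return v >= 8 ? v - 16 : v;
}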
MEGDNN_DEVICE __forceinline__ static void transform_int4x8_to_int8(
        int (&result)[8], const int& source) {
#pragma unroll
    for (int i = 0; i < 8; i++) {
        result[i] = unpack_integer_4bits<true>(
                reinterpret_cast<unsigned const&>(source), (i << 2));
    }
}
MEGDNN_DEVICE __forceinline__ static void transform_uint4x8_to_int8(
        int (&result)[8], const int& source) {
#pragma unroll
    for (int i = 0; i < 8; i++) {
        result[i] = unpack_integer_4bits<false>(
                reinterpret_cast<unsigned const&>(source), (i << 2));
    }
}
#endif
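Taken together, the pack and unpack helpers round-trip any values already in [-8, 7]; a usage sketch built on the hypothetical host references above (pack_s4x8_ref / unpack_s4_ref):

#include <cassert>

// Round-trip check over the reference helpers sketched earlier.
void roundtrip_s4x8() {
    const int s[8] = {-8, -3, -1, 0, 1, 3, 6, 7};  // all within [-8, 7]
    unsigned packed = pack_s4x8_ref(s);
    for (int i = 0; i < 8; ++i)
        assert(unpack_s4_ref(packed, 4 * i) == s[i]);
}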
} // namespace cuda
} // namespace megdnn