GitOrigin-RevId: 1445ecfabe
release-1.5
@@ -196,6 +196,32 @@ public: | |||||
const TensorLayout& layout) const override; | const TensorLayout& layout) const override; | ||||
}; | }; | ||||
using Image2DPack4TensorFormatBase = Image2DPackedTensorFormatBase<4>; | using Image2DPack4TensorFormatBase = Image2DPackedTensorFormatBase<4>; | ||||
///*! | |||||
// * \brief used for tensors with lowbit data type | |||||
// * | |||||
// * \p SIZE_NBITS is the size in bits of element of the tensor. | |||||
// * | |||||
// */ | |||||
//template <size_t SIZE_NBITS_> | |||||
//class LowbitTensorFormat : public TensorFormat::ImplBase { | |||||
// static constexpr size_t SIZE_NBITS = SIZE_NBITS_; | |||||
// size_t m_align_size_in_bits; | |||||
// | |||||
//protected: //? | |||||
// LowbitTensorFormat(Type type, size_t m_align_size_in_bits); | |||||
// | |||||
//public: | |||||
// size_t align_size_in_bits() const { | |||||
// return m_align_size_in_bits; | |||||
// } | |||||
// | |||||
// std::string to_string() const override; | |||||
// | |||||
// void serialize_append( | |||||
// | |||||
// | |||||
//}; | |||||
} // namespace detail | } // namespace detail | ||||
/*! | /*! | ||||
@@ -895,6 +895,7 @@ Relayout mode. | |||||
* ``NCHW4`` layout: ``{N, C/4, H, W, 4}`` | * ``NCHW4`` layout: ``{N, C/4, H, W, 4}`` | ||||
* ``NCHW88`` layout: ``{N, C/8, H, W, 8}`` | * ``NCHW88`` layout: ``{N, C/8, H, W, 8}`` | ||||
* ``CHWN4`` layout: ``{C/4, H, W, N, 4}`` | * ``CHWN4`` layout: ``{C/4, H, W, N, 4}`` | ||||
* ``NCHW64`` layout: ``{N, C/64, H, W, 64}`` | |||||
**Float weight transformation definitions** | **Float weight transformation definitions** | ||||
@@ -969,6 +970,8 @@ Note: NCHW_NCHW4_WEIGHT will auto pad oc and ic, you should remove oc in later o | |||||
'NCHW_NCHW4', | 'NCHW_NCHW4', | ||||
'NCHW4_NCHW', | 'NCHW4_NCHW', | ||||
'NCHW_NCHW4_WEIGHT', | 'NCHW_NCHW4_WEIGHT', | ||||
'NCHW_NCHW64', | |||||
'NCHW64_NCHW', | |||||
) | ) | ||||
) | ) | ||||
@@ -251,6 +251,23 @@ void RelayoutFormat::deduce_layout_fwd(const TensorLayout& src, | |||||
dst[3] = src[3]; | dst[3] = src[3]; | ||||
megdnn_assert(dst[1] % param().group == 0); | megdnn_assert(dst[1] % param().group == 0); | ||||
break; | break; | ||||
case Param::Mode::NCHW_NCHW64: | |||||
megdnn_assert(src.ndim == 4 && (src[1] % 64) == 0); | |||||
dst.ndim = 5; | |||||
dst[0] = src[0]; | |||||
dst[1] = src[1] / 64; | |||||
dst[2] = src[2]; | |||||
dst[3] = src[3]; | |||||
dst[4] = 64; | |||||
break; | |||||
case Param::Mode::NCHW64_NCHW: | |||||
megdnn_assert(src.ndim == 5); | |||||
dst.ndim = 4; | |||||
dst[0] = src[0]; | |||||
dst[1] = src[1] * 64; | |||||
dst[2] = src[2]; | |||||
dst[3] = src[3]; | |||||
break; | |||||
default: | default: | ||||
megdnn_assert(0, "Invalid RelayoutFormat Mode"); | megdnn_assert(0, "Invalid RelayoutFormat Mode"); | ||||
break; | break; | ||||
@@ -352,7 +369,12 @@ void RelayoutFormat::deduce_format(TensorFormat src, TensorFormat& dst) { | |||||
CHECK_SRC(DefaultTensorFormat::make()); | CHECK_SRC(DefaultTensorFormat::make()); | ||||
dst = src; | dst = src; | ||||
break; | break; | ||||
case Param::Mode::NCHW_NCHW64: | |||||
dst = src; | |||||
break; | |||||
case Param::Mode::NCHW64_NCHW: | |||||
dst = src; | |||||
break; | |||||
default: | default: | ||||
megdnn_throw("Invalid relayout format mode"); | megdnn_throw("Invalid relayout format mode"); | ||||
break; | break; | ||||
@@ -633,6 +655,19 @@ void RelayoutFormat::deduce_exec_layout(const TensorLayout& src, | |||||
exec_src = src.dimshuffle({3, 0, 1, 2, 4}); | exec_src = src.dimshuffle({3, 0, 1, 2, 4}); | ||||
exec_dst = dst; | exec_dst = dst; | ||||
break; | break; | ||||
case Param::Mode::NCHW_NCHW64: | |||||
// src is {N, C, H, W} | |||||
// dst is {N, C/64, H, W, 64} | |||||
exec_src = src.reshape({src[0], src[1] / 64, 64, src[2], src[3]}) | |||||
.dimshuffle({0, 1, 3, 4, 2}); | |||||
exec_dst = dst; | |||||
break; | |||||
case Param::Mode::NCHW64_NCHW: | |||||
// src is {N, C/64, H, W, 64} | |||||
// dst is {N, C, H, W} | |||||
exec_src = src.dimshuffle({0, 1, 4, 2, 3}); | |||||
exec_dst = dst; | |||||
break; | |||||
default: | default: | ||||
megdnn_assert(0, "Invalid RelayoutFormat Mode"); | megdnn_assert(0, "Invalid RelayoutFormat Mode"); | ||||
} | } | ||||
@@ -69,12 +69,9 @@ size_t ConvBiasForwardImpl::AlgoFallbackNCHWQS4::get_workspace_in_bytes( | |||||
void ConvBiasForwardImpl::AlgoFallbackNCHWQS4::exec( | void ConvBiasForwardImpl::AlgoFallbackNCHWQS4::exec( | ||||
const ExecArgs& args) const { | const ExecArgs& args) const { | ||||
using Format = Param::Format; | |||||
auto&& param = args.opr->param(); | |||||
auto&& fm = args.filter_meta; | |||||
auto layouts = make_underlying_tensor_layout( | auto layouts = make_underlying_tensor_layout( | ||||
*(args.src_layout), fm, *(args.bias_layout), *(args.z_layout), | |||||
*(args.dst_layout)); | |||||
*(args.src_layout), *(args.filter_layout), *(args.bias_layout), | |||||
*(args.z_layout), *(args.dst_layout)); | |||||
auto ws = get_workspace_bundle(args.workspace.raw_ptr, args); | auto ws = get_workspace_bundle(args.workspace.raw_ptr, args); | ||||
auto ws_src = ws.get(0); | auto ws_src = ws.get(0); | ||||
auto ws_filter = ws.get(1); | auto ws_filter = ws.get(1); | ||||
@@ -82,20 +79,27 @@ void ConvBiasForwardImpl::AlgoFallbackNCHWQS4::exec( | |||||
void* ws_z = nullptr; | void* ws_z = nullptr; | ||||
if (args.z_layout->ndim > 0) | if (args.z_layout->ndim > 0) | ||||
ws_z = ws.get(4); | ws_z = ws.get(4); | ||||
auto&& stream = cuda_stream(args.opr->handle()); | |||||
auto nchw2nchw64 = [](const TensorND& src, void* raw_dptr) { | |||||
if (raw_dptr == nullptr) | |||||
// auto&& stream = cuda_stream(args.opr->handle()); | |||||
auto nchw2nchw64 = [&args](const TensorND& src, TensorND&& dst) { | |||||
if (dst.raw_ptr == nullptr) | |||||
return; | return; | ||||
auto relayout = args.handle->create_operator<RelayoutFormat>(); | |||||
relayout->param() = RelayoutFormat::Param::Mode::NCHW_NCHW64; | |||||
Workspace dummy; | |||||
relayout->exec(src, dst, dummy); | |||||
}; | }; | ||||
auto nchw642nchw = [](const TensorND& src, void* raw_dptr) { | |||||
auto nchw642nchw = [&args](const TensorND& src, TensorND&& dst) { | |||||
auto relayout = args.handle->create_operator<RelayoutFormat>(); | |||||
relayout->param() = RelayoutFormat::Param::Mode::NCHW64_NCHW; | |||||
Workspace dummy; | |||||
relayout->exec(src, dst, dummy); | |||||
}; | }; | ||||
// reformat src | // reformat src | ||||
nchw2nchw64(*(args.src_tensor), ws_src); | |||||
nchw2nchw64(*(args.src_tensor), {ws_src, layouts[0]}); | |||||
// reformat filter | // reformat filter | ||||
nchw2nchw64(*(args.filter_tensor), ws_filter); | |||||
nchw2nchw64(*(args.filter_tensor), {ws_filter, layouts[1]}); | |||||
// reformat z | // reformat z | ||||
nchw2nchw64(*(args.z_tensor), ws_z); | |||||
nchw2nchw64(*(args.z_tensor), {ws_z, layouts[3]}); | |||||
TensorND src_{ws_src, layouts[0]}, filter_{ws_filter, layouts[1]}, | TensorND src_{ws_src, layouts[0]}, filter_{ws_filter, layouts[1]}, | ||||
bias_{args.bias_tensor->raw_ptr, layouts[2]}, z_{ws_z, layouts[3]}, | bias_{args.bias_tensor->raw_ptr, layouts[2]}, z_{ws_z, layouts[3]}, | ||||
dst_{ws_dst, layouts[4]}; | dst_{ws_dst, layouts[4]}; | ||||
@@ -109,22 +113,22 @@ void ConvBiasForwardImpl::AlgoFallbackNCHWQS4::exec( | |||||
args.preprocessed_filter}; | args.preprocessed_filter}; | ||||
m_underlying_algo.exec(args); | m_underlying_algo.exec(args); | ||||
// reformat dst | // reformat dst | ||||
nchw642nchw(dst_, args.dst_tensor->raw_ptr); | |||||
nchw642nchw(dst_, {args.dst_tensor->raw_ptr, args.dst_tensor->layout}); | |||||
} | } | ||||
SmallVector<TensorLayout> | SmallVector<TensorLayout> | ||||
ConvBiasForwardImpl::AlgoFallbackNCHWQS4::make_underlying_tensor_layout( | ConvBiasForwardImpl::AlgoFallbackNCHWQS4::make_underlying_tensor_layout( | ||||
const TensorLayout& src, const CanonizedFilterMeta& filter_meta, | |||||
const TensorLayout& src, const TensorLayout& filter, | |||||
const TensorLayout& bias, const TensorLayout& z, | const TensorLayout& bias, const TensorLayout& z, | ||||
const TensorLayout& dst) const { | const TensorLayout& dst) const { | ||||
size_t n = src[0], ci = src[1], hi = src[2], wi = src[3]; | size_t n = src[0], ci = src[1], hi = src[2], wi = src[3]; | ||||
size_t co = dst[1], ho = dst[2], wo = dst[3]; | size_t co = dst[1], ho = dst[2], wo = dst[3]; | ||||
size_t fh = filter_meta.spatial[0], fw = filter_meta.spatial[1]; | |||||
size_t fh = filter[2], fw = filter[3]; | |||||
SmallVector<TensorLayout> rst; | SmallVector<TensorLayout> rst; | ||||
rst.emplace_back(TensorLayout{{n, ci / 64, hi, wi, 64}, src.dtype}); | rst.emplace_back(TensorLayout{{n, ci / 64, hi, wi, 64}, src.dtype}); | ||||
rst.emplace_back(TensorLayout{{co, ci / 64, fh, fw, 64}, filter.dtype}); | rst.emplace_back(TensorLayout{{co, ci / 64, fh, fw, 64}, filter.dtype}); | ||||
rst.emplace_back(TensorLayout{{1, co / 64, 1, 1, 64}, bias.dtype}); | rst.emplace_back(TensorLayout{{1, co / 64, 1, 1, 64}, bias.dtype}); | ||||
if (z.layout.ndim > 0) { | |||||
if (z.ndim > 0) { | |||||
rst.emplace_back(TensorLayout{{n, co / 64, ho, wo, 64}, z.dtype}); | rst.emplace_back(TensorLayout{{n, co / 64, ho, wo, 64}, z.dtype}); | ||||
} else { | } else { | ||||
rst.emplace_back(TensorLayout{}); | rst.emplace_back(TensorLayout{}); | ||||
@@ -134,15 +138,13 @@ ConvBiasForwardImpl::AlgoFallbackNCHWQS4::make_underlying_tensor_layout( | |||||
} | } | ||||
WorkspaceBundle ConvBiasForwardImpl::AlgoFallbackNCHWQS4::get_workspace_bundle( | WorkspaceBundle ConvBiasForwardImpl::AlgoFallbackNCHWQS4::get_workspace_bundle( | ||||
void* ptr, const SizeArgs& args) const { | |||||
void* raw_ptr, const SizeArgs& args) const { | |||||
size_t ws_size_src = args.src_layout->span().dist_byte(); | size_t ws_size_src = args.src_layout->span().dist_byte(); | ||||
size_t ws_size_filter = args.filter_layout->span().dist_byte(); | size_t ws_size_filter = args.filter_layout->span().dist_byte(); | ||||
size_t ws_size_dst = args.dst_layout->span().dist_byte(); | size_t ws_size_dst = args.dst_layout->span().dist_byte(); | ||||
auto&& param = args.opr->param(); | |||||
auto&& fm = args.filter_meta; | |||||
auto layouts = make_underlying_tensor_layout( | auto layouts = make_underlying_tensor_layout( | ||||
*(args.src_layout), fm, *(args.bias_layout), *(args.z_layout), | |||||
*(args.dst_layout)); | |||||
*(args.src_layout), *(args.filter_layout), *(args.bias_layout), | |||||
*(args.z_layout), *(args.dst_layout)); | |||||
SizeArgs args_{args.opr, | SizeArgs args_{args.opr, | ||||
layouts[0], | layouts[0], | ||||
layouts[1], | layouts[1], | ||||
@@ -78,29 +78,33 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||||
return handle()->create_operator<RelayoutForward>()->exec( | return handle()->create_operator<RelayoutForward>()->exec( | ||||
{src.raw_ptr, exec_src_layout}, {dst.raw_ptr, exec_dst_layout}); | {src.raw_ptr, exec_src_layout}, {dst.raw_ptr, exec_dst_layout}); | ||||
} | } | ||||
if (param().mode == Param::Mode::NCHW_NCHW4 || | |||||
param().mode == Param::Mode::NCHW4_NCHW || | |||||
param().mode == Param::Mode::NCHW_NCHW4_WEIGHT) { | |||||
bool is_trans_4bits = (param().mode == Param::Mode::NCHW_NCHW64 || | |||||
param().mode == Param::Mode::NCHW64_NCHW) && | |||||
(src_dtype.enumv() == DTypeEnum::QuantizedS4 || | |||||
src_dtype.enumv() == DTypeEnum::Quantized4Asymm); | |||||
bool is_nchw_nchw4 = param().mode == Param::Mode::NCHW_NCHW4 || | |||||
param().mode == Param::Mode::NCHW4_NCHW || | |||||
param().mode == Param::Mode::NCHW_NCHW4_WEIGHT; | |||||
if (is_trans_4bits || is_nchw_nchw4) { | |||||
bool is_usable = relayout_format::RelayoutFormatFast::usable( | bool is_usable = relayout_format::RelayoutFormatFast::usable( | ||||
src.layout, dst.layout); | src.layout, dst.layout); | ||||
megdnn_assert(is_usable, | megdnn_assert(is_usable, | ||||
"RelayoutFormatNCHW_NCHW4 kernel not usable for %s(%s) " | |||||
"to %s(%s)", | |||||
"RelayoutFormatFast kernel is not usable for " | |||||
"transforming %s(%s) to %s(%s).", | |||||
src.layout.to_string().c_str(), src.layout.dtype.name(), | src.layout.to_string().c_str(), src.layout.dtype.name(), | ||||
dst.layout.to_string().c_str(), dst.layout.dtype.name()); | dst.layout.to_string().c_str(), dst.layout.dtype.name()); | ||||
relayout_format::RelayoutFormatFast::exec(src, dst, | |||||
cuda_stream(this->handle()), | |||||
param().mode, param().group); | |||||
} else { | |||||
TensorLayout exec_src, exec_dst, exec_workspace; | |||||
deduce_exec_layout(src.layout, dst.layout, exec_workspace, exec_src, | |||||
exec_dst); | |||||
TensorND exec_src_nd{src.raw_ptr, exec_src}; | |||||
TensorND exec_dst_nd{dst.raw_ptr, exec_dst}; | |||||
handle()->create_operator<RelayoutForward>()->exec(exec_src_nd, | |||||
exec_dst_nd); | |||||
return relayout_format::RelayoutFormatFast::exec( | |||||
src, dst, cuda_stream(this->handle()), param().mode, | |||||
param().group); | |||||
} | } | ||||
// fallback impls | |||||
TensorLayout exec_src, exec_dst, exec_workspace; | |||||
deduce_exec_layout(src.layout, dst.layout, exec_workspace, exec_src, | |||||
exec_dst); | |||||
TensorND exec_src_nd{src.raw_ptr, exec_src}; | |||||
TensorND exec_dst_nd{dst.raw_ptr, exec_dst}; | |||||
handle()->create_operator<RelayoutForward>()->exec(exec_src_nd, | |||||
exec_dst_nd); | |||||
} | } | ||||
size_t RelayoutFormatImpl::get_workspace_in_bytes( | size_t RelayoutFormatImpl::get_workspace_in_bytes( | ||||
@@ -24,6 +24,8 @@ inline void get_scale_zeropoint(const DType& tensor_dtype, float& scale, | |||||
scale = tensor_dtype.param<dtype::Quantized8Asymm>().scale; | scale = tensor_dtype.param<dtype::Quantized8Asymm>().scale; | ||||
} else if (tensor_dtype.enumv() == DTypeEnum::QuantizedS8) { | } else if (tensor_dtype.enumv() == DTypeEnum::QuantizedS8) { | ||||
scale = tensor_dtype.param<dtype::QuantizedS8>().scale; | scale = tensor_dtype.param<dtype::QuantizedS8>().scale; | ||||
} else if (tensor_dtype.enumv() == DTypeEnum::QuantizedS4) { | |||||
scale = tensor_dtype.param<dtype::QuantizedS4>().scale; | |||||
} | } | ||||
} | } | ||||
@@ -39,9 +41,8 @@ void relayout_format::RelayoutFormatFast::exec(const TensorND& src, | |||||
cudaStream_t stream, | cudaStream_t stream, | ||||
RelayoutFormat::Param::Mode mode, | RelayoutFormat::Param::Mode mode, | ||||
int group) { | int group) { | ||||
size_t ih = src.layout[2]; | |||||
size_t iw = src.layout[3]; | |||||
size_t hw = ih * iw; | |||||
auto&& stype = src.layout.dtype; | |||||
auto&& dtype = dst.layout.dtype; | |||||
float src_scale = 1.f; | float src_scale = 1.f; | ||||
float dst_scale = 1.f; | float dst_scale = 1.f; | ||||
uint8_t src_zero_point = 0; | uint8_t src_zero_point = 0; | ||||
@@ -51,22 +52,28 @@ void relayout_format::RelayoutFormatFast::exec(const TensorND& src, | |||||
if (src.layout.dtype.enumv() == DTypeEnum::Uint8) { | if (src.layout.dtype.enumv() == DTypeEnum::Uint8) { | ||||
src_zero_point = 128; | src_zero_point = 128; | ||||
} | } | ||||
if (mode == RelayoutFormat::Param::Mode::NCHW_NCHW4) { | |||||
if (hw % 4 == 0) { | |||||
relayout_format_cuda_nchw_nchw4<4>(src, dst, stream, src_scale, | |||||
if (mode == RelayoutFormat::Param::Mode::NCHW_NCHW4 || | |||||
mode == RelayoutFormat::Param::Mode::NCHW_NCHW64) { | |||||
return relayout_format_cuda_nchw_nchwx(src, dst, stream, src_scale, | |||||
dst_scale, src_zero_point, | dst_scale, src_zero_point, | ||||
dst_zero_point, group); | dst_zero_point, group); | ||||
} else { | |||||
relayout_format_cuda_nchw_nchw4<1>(src, dst, stream, src_scale, | |||||
} else if (mode == RelayoutFormat::Param::Mode::NCHW64_NCHW) { | |||||
megdnn_assert(group == 1, | |||||
"RelayoutFormat kernel only support transforming NCHW64 " | |||||
"to NCHW with group = 1(group:%d)", | |||||
group); | |||||
return relayout_format_cuda_nchwx_nchw(src, dst, stream, src_scale, | |||||
dst_scale, src_zero_point, | dst_scale, src_zero_point, | ||||
dst_zero_point, group); | |||||
} | |||||
dst_zero_point); | |||||
} else if (mode == RelayoutFormat::Param::Mode::NCHW_NCHW4_WEIGHT) { | } else if (mode == RelayoutFormat::Param::Mode::NCHW_NCHW4_WEIGHT) { | ||||
relayout_format_cuda_nchw_nchw4_weight(src, dst, stream); | |||||
return relayout_format_cuda_nchw_nchw4_weight(src, dst, stream); | |||||
} else if (mode == RelayoutFormat::Param::Mode::NCHW4_NCHW) { | } else if (mode == RelayoutFormat::Param::Mode::NCHW4_NCHW) { | ||||
relayout_format_cuda_nchw4_nchw(src, dst, stream, group); | |||||
return relayout_format_cuda_nchw4_nchw(src, dst, stream, group); | |||||
} else { | } else { | ||||
megdnn_throw("only support nchw_nchw4 nchw4_nchw layout_format"); | |||||
megdnn_throw( | |||||
"only support nchw_nchw64/nchw64_nchw/nchw_nchw4/nchw4_nchw " | |||||
"layout_format"); | |||||
} | } | ||||
} | } | ||||
// vim: ft=cpp syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -19,14 +19,11 @@ namespace megdnn { | |||||
namespace cuda { | namespace cuda { | ||||
namespace relayout_format { | namespace relayout_format { | ||||
template <int pack_w = 1> | |||||
void relayout_format_cuda_nchw_nchw4(const TensorND& src, const TensorND& dst, | |||||
const cudaStream_t& stream, | |||||
const float src_scale = 1.f, | |||||
const float dst_scale = 1.f, | |||||
const uint8_t src_zero_point = 0, | |||||
const uint8_t dst_zero_point = 0, | |||||
const int group = 1); | |||||
void relayout_format_cuda_nchw_nchwx( | |||||
const TensorND& src, const TensorND& dst, const cudaStream_t& stream, | |||||
const float src_scale = 1.f, const float dst_scale = 1.f, | |||||
const uint8_t src_zero_point = 0, const uint8_t dst_zero_point = 0, | |||||
const int group = 1); | |||||
bool relayout_format_cuda_usable(const TensorLayout& src_layout, | bool relayout_format_cuda_usable(const TensorLayout& src_layout, | ||||
const TensorLayout& dst_layout); | const TensorLayout& dst_layout); | ||||
@@ -35,6 +32,13 @@ void relayout_format_cuda_nchw4_nchw(const TensorND& src, const TensorND& dst, | |||||
const cudaStream_t& stream, | const cudaStream_t& stream, | ||||
const int group); | const int group); | ||||
void relayout_format_cuda_nchwx_nchw(const TensorND& src, const TensorND& dst, | |||||
const cudaStream_t& stream, | |||||
const float src_scale = 1.f, | |||||
const float dst_scale = 1.f, | |||||
const uint8_t src_zero_point = 0, | |||||
const uint8_t dst_zero_point = 0); | |||||
void relayout_format_cuda_nchw_nchw4_weight(const TensorND& src, | void relayout_format_cuda_nchw_nchw4_weight(const TensorND& src, | ||||
const TensorND& dst, | const TensorND& dst, | ||||
const cudaStream_t& stream); | const cudaStream_t& stream); | ||||
@@ -110,6 +110,12 @@ MEGDNN_NORETURN void report_error(const char* msg); | |||||
template <typename T, size_t N> | template <typename T, size_t N> | ||||
struct array_wrapper { | struct array_wrapper { | ||||
T data[N]; | T data[N]; | ||||
MEGDNN_DEVICE __forceinline__ T& operator[](size_t pos) { | |||||
return reinterpret_cast<T&>(data[pos]); | |||||
} | |||||
MEGDNN_DEVICE __forceinline__ T const& operator[](size_t pos) const { | |||||
return reinterpret_cast<T const&>(data[pos]); | |||||
} | |||||
}; | }; | ||||
/*! | /*! | ||||
@@ -207,12 +213,29 @@ struct CudaDTypeParamImpl<dt_quint4> : DTypeParamImpl<dt_quint4> { | |||||
CudaDTypeParamImpl(const DTypeParamImpl<dt_quint4>& param) | CudaDTypeParamImpl(const DTypeParamImpl<dt_quint4>& param) | ||||
: CudaDTypeParamImpl(param.scale, param.zero_point) {} | : CudaDTypeParamImpl(param.scale, param.zero_point) {} | ||||
__device__ uint8_t quantize(float in) const { | |||||
__device__ dt_quint4 quantize(float in) const { | |||||
float v = in * inv_scale; | float v = in * inv_scale; | ||||
v = roundf(v); | v = roundf(v); | ||||
v = v + zero_point; | v = v + zero_point; | ||||
v = fmin(fmax(0.f, v), 15.f); | v = fmin(fmax(0.f, v), 15.f); | ||||
return static_cast<uint8_t>(v); | |||||
return static_cast<dt_quint4>(v); | |||||
} | |||||
}; | |||||
template <> | |||||
struct CudaDTypeParamImpl<dt_qint4> : DTypeParamImpl<dt_qint4> { | |||||
float inv_scale; | |||||
CudaDTypeParamImpl() = default; | |||||
CudaDTypeParamImpl(float scale) | |||||
: DTypeParamImpl<dt_qint4>(scale), inv_scale(1.0f / scale) {} | |||||
CudaDTypeParamImpl(const DTypeParamImpl<dt_qint4>& param) | |||||
: CudaDTypeParamImpl(param.scale) {} | |||||
__device__ dt_qint4 quantize(float in) const { | |||||
float v = in * inv_scale; | |||||
v = roundf(v); | |||||
v = fmin(fmax(-8.f, v), 7.f); | |||||
return static_cast<dt_qint4>(v); | |||||
} | } | ||||
}; | }; | ||||
@@ -351,6 +374,110 @@ MEGDNN_DEVICE __forceinline__ static float4 operator+(float4 lval, | |||||
return make_float4(lval.x + rval.x, lval.y + rval.y, lval.z + rval.z, | return make_float4(lval.x + rval.x, lval.y + rval.y, lval.z + rval.z, | ||||
lval.w + rval.w); | lval.w + rval.w); | ||||
} | } | ||||
MEGDNN_DEVICE __forceinline__ static int transform_int8_to_int4x8( | |||||
int s0, int s1, int s2, int s3, int s4, int s5, int s6, int s7) { | |||||
unsigned out; | |||||
#if __CUDA_ARCH__ >= 750 | |||||
asm volatile( | |||||
"{ .reg .u32 r4;" | |||||
"cvt.pack.sat.s4.s32.b32 r4, %8, %7, 0;" | |||||
"cvt.pack.sat.s4.s32.b32 r4, %6, %5, r4;" | |||||
"cvt.pack.sat.s4.s32.b32 r4, %4, %3, r4;" | |||||
"cvt.pack.sat.s4.s32.b32 %0, %2, %1, r4;" | |||||
"}" | |||||
: "=r"(out) | |||||
: "r"(s0), "r"(s1), "r"(s2), "r"(s3), "r"(s4), "r"(s5), "r"(s6), | |||||
"r"(s7)); | |||||
#else | |||||
#define CVT_SAT_S4_S32(r, bits) \ | |||||
r = r <= -8 ? -8 : r; \ | |||||
r = r > 7 ? 7 : r; \ | |||||
r = (((unsigned)r & 0xf) << bits); | |||||
CVT_SAT_S4_S32(s0, 0) | |||||
CVT_SAT_S4_S32(s1, 4) | |||||
CVT_SAT_S4_S32(s2, 8) | |||||
CVT_SAT_S4_S32(s3, 12) | |||||
CVT_SAT_S4_S32(s4, 16) | |||||
CVT_SAT_S4_S32(s5, 20) | |||||
CVT_SAT_S4_S32(s6, 24) | |||||
CVT_SAT_S4_S32(s7, 28) | |||||
out = s0 + s1 + s2 + s3 + s4 + s5 + s6 + s7; | |||||
#undef CVT_SAT_S4_S32 | |||||
#endif | |||||
return reinterpret_cast<int const&>(out); | |||||
} | |||||
MEGDNN_DEVICE __forceinline__ static int transform_int8_to_uint4x8( | |||||
int s0, int s1, int s2, int s3, int s4, int s5, int s6, int s7) { | |||||
unsigned out; | |||||
#if __CUDA_ARCH__ >= 750 | |||||
asm volatile( | |||||
"{ .reg .u32 r4;" | |||||
"cvt.pack.sat.u4.s32.b32 r4, %8, %7, 0;" | |||||
"cvt.pack.sat.u4.s32.b32 r4, %6, %5, r4;" | |||||
"cvt.pack.sat.u4.s32.b32 r4, %4, %3, r4;" | |||||
"cvt.pack.sat.u4.s32.b32 %0, %2, %1, r4;" | |||||
"}" | |||||
: "=r"(out) | |||||
: "r"(s0), "r"(s1), "r"(s2), "r"(s3), "r"(s4), "r"(s5), "r"(s6), | |||||
"r"(s7)); | |||||
#else | |||||
#define CVT_SAT_U4_S32(r, bits) \ | |||||
r = r <= 0 ? 0 : r; \ | |||||
r = r > 15 ? 15 : r; \ | |||||
r = (((unsigned)r & 0xf) << bits); | |||||
CVT_SAT_U4_S32(s0, 0) | |||||
CVT_SAT_U4_S32(s1, 4) | |||||
CVT_SAT_U4_S32(s2, 8) | |||||
CVT_SAT_U4_S32(s3, 12) | |||||
CVT_SAT_U4_S32(s4, 16) | |||||
CVT_SAT_U4_S32(s5, 20) | |||||
CVT_SAT_U4_S32(s6, 24) | |||||
CVT_SAT_U4_S32(s7, 28) | |||||
out = s0 + s1 + s2 + s3 + s4 + s5 + s6 + s7; | |||||
#undef CVT_SAT_U4_S32 | |||||
#endif | |||||
return reinterpret_cast<int const&>(out); | |||||
} | |||||
template <bool signedness> | |||||
MEGDNN_DEVICE __forceinline__ static int unpack_integer_4bits(unsigned storage, | |||||
unsigned bits); | |||||
template <> | |||||
MEGDNN_DEVICE __forceinline__ int unpack_integer_4bits<true>(unsigned storage, | |||||
unsigned bits) { | |||||
uint8_t result = (uint8_t)((unsigned)(storage >> bits) & 0xf); | |||||
static constexpr uint8_t mask = (uint8_t)((1 << 4) - 1); | |||||
return (result & uint8_t(1 << 3)) ? ((int)(result) | ~(int)(mask)) | |||||
: (int)(result); | |||||
} | |||||
template <> | |||||
MEGDNN_DEVICE __forceinline__ int unpack_integer_4bits<false>(unsigned storage, | |||||
unsigned bits) { | |||||
uint8_t result = (uint8_t)((unsigned)(storage >> bits) & 0xf); | |||||
return (int)(result); | |||||
} | |||||
MEGDNN_DEVICE __forceinline__ static void transform_int4x8_to_int8( | |||||
int (&result)[8], const int& source) { | |||||
#pragma unroll | |||||
for (int i = 0; i < 8; i++) { | |||||
result[i] = unpack_integer_4bits<true>( | |||||
reinterpret_cast<unsigned const&>(source), (i << 2)); | |||||
} | |||||
} | |||||
MEGDNN_DEVICE __forceinline__ static void transform_uint4x8_to_int8( | |||||
int (&result)[8], const int& source) { | |||||
#pragma unroll | |||||
for (int i = 0; i < 8; i++) { | |||||
result[i] = unpack_integer_4bits<false>( | |||||
reinterpret_cast<unsigned const&>(source), (i << 2)); | |||||
} | |||||
} | |||||
#endif | #endif | ||||
} // namespace cuda | } // namespace cuda | ||||
} // namespace megdnn | } // namespace megdnn | ||||