GitOrigin-RevId: e9ecbcf2e2
release-1.5
@@ -252,10 +252,10 @@ void RelayoutFormat::deduce_layout_fwd(const TensorLayout& src,
             megdnn_assert(dst[1] % param().group == 0);
             break;
         case Param::Mode::NCHW_NCHW64:
-            megdnn_assert(src.ndim == 4 && (src[1] % 64) == 0);
+            megdnn_assert(src.ndim == 4);
             dst.ndim = 5;
             dst[0] = src[0];
-            dst[1] = src[1] / 64;
+            dst[1] = div_ceil(src[1], 64_z);
             dst[2] = src[2];
             dst[3] = src[3];
             dst[4] = 64;
@@ -264,7 +264,7 @@ void RelayoutFormat::deduce_layout_fwd(const TensorLayout& src,
             megdnn_assert(src.ndim == 5);
             dst.ndim = 4;
             dst[0] = src[0];
-            dst[1] = src[1] * 64;
+            dst[1] = param().oc == 0 ? src[1] * 64 : param().oc;
             dst[2] = src[2];
             dst[3] = src[3];
             break;
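// A minimal sketch of the rounding helpers used above, under the assumption
// that they carry the usual ceiling-division semantics (the real definitions
// live in megdnn's utility headers; 64_z is a size_t literal):
static inline size_t div_ceil(size_t a, size_t b) { return (a + b - 1) / b; }
static inline size_t round_up(size_t a, size_t b) { return div_ceil(a, b) * b; }
// e.g. src[1] = 15 gives dst[1] = div_ceil(15, 64) = 1, so a 15-channel NCHW
// tensor deduces to a single zero-padded {N, 1, H, W, 64} block.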
@@ -483,12 +483,11 @@ void RelayoutFormat::deduce_exec_layout(const TensorLayout& src,
         case Param::Mode::NCHW4_NCHW:
             // nchw4 to nchw
             {
-                megdnn_assert(src.format == dst.format);
                 exec_workspace =
                         TensorLayout({src[0], src[1] * 4, src[2], src[3]},
-                                     src.dtype, src.format)
-                                .reshape({src[0], src[1], 4, src[2], src[3]})
-                                .dimshuffle({0, 1, 3, 4, 2});
-                exec_src = src;
+                                     dst.dtype, dst.format);
+                exec_src = src.dimshuffle({0, 1, 4, 2, 3});
                 exec_dst = dst;
             }
             break;
@@ -658,13 +657,20 @@ void RelayoutFormat::deduce_exec_layout(const TensorLayout& src,
         case Param::Mode::NCHW_NCHW64:
             // src is {N, C, H, W}
             // dst is {N, (C + 63) / 64, H, W, 64}
-            exec_src = src.reshape({src[0], src[1] / 64, 64, src[2], src[3]})
+            exec_workspace = TensorLayout(
+                    {src[0], round_up(src[1], 64_z), src[2], src[3]},
+                    src.dtype);
+            exec_src = exec_workspace
+                               .reshape({src[0], div_ceil(src[1], 64_z), 64,
+                                         src[2], src[3]})
                                .dimshuffle({0, 1, 3, 4, 2});
             exec_dst = dst;
             break;
         case Param::Mode::NCHW64_NCHW:
             // src is {N, (C + 63) / 64, H, W, 64}
             // dst is {N, C, H, W}
+            exec_workspace = TensorLayout({src[0], src[1] * 64, src[2], src[3]},
+                                          dst.dtype);
             exec_src = src.dimshuffle({0, 1, 4, 2, 3});
             exec_dst = dst;
             break;
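// Shape walk-through of the padded NCHW -> NCHW64 path above (a sketch using
// an assumed src = {1, 15, 7, 7}): exec_workspace is the channel-padded
// {1, 64, 7, 7}; the reshape views it as {1, 1, 64, 7, 7}; and
// dimshuffle({0, 1, 3, 4, 2}) reorders that view to {1, 1, 7, 7, 64}, which
// matches the deduced NCHW64 destination.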
@@ -0,0 +1,149 @@
/**
 * \file dnn/src/cuda/relayout_format/helper.cuh
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied.
 */
namespace megdnn {
namespace cuda {
namespace relayout_format {

#define devfunc __forceinline__ __device__

template <int size_nbits>
devfunc int make_zero(int zero_point);

template <>
devfunc int make_zero<4>(int zero_point) {
    return transform_int8_to_uint4x8(zero_point, zero_point, zero_point,
                                     zero_point, zero_point, zero_point,
                                     zero_point, zero_point);
}
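// A scalar model of make_zero<4> (a sketch: it assumes
// transform_int8_to_uint4x8 packs its eight arguments into one 32-bit word,
// four bits per value, which is all the zero-point broadcast relies on):
devfunc int make_zero_scalar_model(int zero_point) {
    int packed = 0;
    for (int i = 0; i < 8; ++i)
        packed |= (zero_point & 0xf) << (4 * i);  // replicate the nibble
    return packed;  // e.g. zero_point = 8 -> 0x88888888
}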
template <typename AccessType, int LoadBytes>
struct global_load_with_zero_point;

/////////////////////////////////////////////////////////////////////////////////////////////////
//
// Specializations
//
/////////////////////////////////////////////////////////////////////////////////////////////////

// The redundant mov PTX instructions force the compiler to initialize the
// destination registers (here with the zero-point pattern) before the
// predicated ld.global.
template <typename AccessType>
struct global_load_with_zero_point<AccessType, 32> {
    devfunc global_load_with_zero_point(AccessType& D, void const* ptr,
                                        bool pred_guard, int zero_point) {
        uint4* data = reinterpret_cast<uint4*>(&D);
        asm volatile(
                "{\n"
                " .reg .pred p;\n"
                " setp.ne.b32 p, %9, 0;\n"
                " mov.b32 %0, %10;\n"
                " mov.b32 %1, %10;\n"
                " mov.b32 %2, %10;\n"
                " mov.b32 %3, %10;\n"
                " mov.b32 %4, %10;\n"
                " mov.b32 %5, %10;\n"
                " mov.b32 %6, %10;\n"
                " mov.b32 %7, %10;\n"
                " @p ld.global.v4.u32 {%0, %1, %2, %3}, [%8];\n"
                " @p ld.global.v4.u32 {%4, %5, %6, %7}, [%11];\n"
                "}\n"
                : "=r"(data[0].x), "=r"(data[0].y), "=r"(data[0].z),
                  "=r"(data[0].w), "=r"(data[1].x), "=r"(data[1].y),
                  "=r"(data[1].z), "=r"(data[1].w)
                : "l"(ptr), "r"((int)pred_guard),
                  "r"(reinterpret_cast<unsigned&>(zero_point)),
                  "l"(((uint8_t*)ptr) + 16));
    }
};
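// A host-side model of the predicated load above (a sketch, not the kernel
// code): every word of the fragment is first filled with the zero-point
// pattern, and global memory overwrites it only when the guard holds.
template <typename AccessType>
void global_load_model(AccessType& D, const void* ptr, bool pred_guard,
                       unsigned zero_point) {
    unsigned* words = reinterpret_cast<unsigned*>(&D);
    const size_t n = sizeof(AccessType) / sizeof(unsigned);
    for (size_t i = 0; i < n; ++i)
        words[i] = zero_point;  // the redundant "mov" instructions
    if (pred_guard) {
        const unsigned* src = static_cast<const unsigned*>(ptr);
        for (size_t i = 0; i < n; ++i)
            words[i] = src[i];  // the guarded ld.global
    }
}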
template <typename AccessType>
struct global_load_with_zero_point<AccessType, 16> {
    devfunc global_load_with_zero_point(AccessType& D, void const* ptr,
                                        bool pred_guard, int zero_point) {
        uint4& data = reinterpret_cast<uint4&>(D);
        asm volatile(
                "{\n"
                " .reg .pred p;\n"
                " setp.ne.b32 p, %5, 0;\n"
                " mov.b32 %0, %6;\n"
                " mov.b32 %1, %6;\n"
                " mov.b32 %2, %6;\n"
                " mov.b32 %3, %6;\n"
                " @p ld.global.v4.u32 {%0, %1, %2, %3}, [%4];\n"
                "}\n"
                : "=r"(data.x), "=r"(data.y), "=r"(data.z), "=r"(data.w)
                : "l"(ptr), "r"((int)pred_guard),
                  "r"(reinterpret_cast<unsigned&>(zero_point)));
    }
};

template <typename AccessType>
struct global_load_with_zero_point<AccessType, 8> {
    devfunc global_load_with_zero_point(AccessType& D, void const* ptr,
                                        bool pred_guard, int zero_point) {
        uint2& data = reinterpret_cast<uint2&>(D);
        asm volatile(
                "{\n"
                " .reg .pred p;\n"
                " setp.ne.b32 p, %3, 0;\n"
                " mov.b32 %0, %4;\n"
                " mov.b32 %1, %4;\n"
                " @p ld.global.v2.u32 {%0, %1}, [%2];\n"
                "}\n"
                : "=r"(data.x), "=r"(data.y)
                : "l"(ptr), "r"((int)pred_guard),
                  "r"(reinterpret_cast<unsigned&>(zero_point)));
    }
};

template <typename AccessType>
struct global_load_with_zero_point<AccessType, 4> {
    devfunc global_load_with_zero_point(AccessType& D, void const* ptr,
                                        bool pred_guard, int zero_point) {
        unsigned& data = reinterpret_cast<unsigned&>(D);
        asm volatile(
                "{\n"
                " .reg .pred p;\n"
                " setp.ne.b32 p, %2, 0;\n"
                " mov.b32 %0, %3;\n"
                " @p ld.global.u32 %0, [%1];\n"
                "}\n"
                : "=r"(data)
                : "l"(ptr), "r"((int)pred_guard),
                  "r"(reinterpret_cast<unsigned&>(zero_point)));
    }
};

template <typename AccessType>
struct global_load_with_zero_point<AccessType, 1> {
    devfunc global_load_with_zero_point(AccessType& D, void const* ptr,
                                        bool pred_guard, int zero_point) {
        if (pred_guard)
            D = *(reinterpret_cast<AccessType const*>(ptr));
        else {
            unsigned uv = reinterpret_cast<unsigned&>(zero_point);
            uint8_t& data = reinterpret_cast<uint8_t&>(D);
            data = uv & 0xff;
        }
    }
};

#undef devfunc

}  // namespace relayout_format
}  // namespace cuda
}  // namespace megdnn
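// Hypothetical usage of the helper (a sketch; frag, gptr and in_bounds are
// assumed names, not part of this patch): load 16 bytes of packed u4 data,
// falling back to the broadcast zero point for out-of-bounds accesses.
//
//     array_wrapper<uint8_t, 16> frag;
//     int zp = relayout_format::make_zero<4>(zero_point);
//     relayout_format::global_load_with_zero_point<decltype(frag), 16>(
//             frag, gptr, in_bounds, zp);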
@@ -18,6 +18,7 @@
 #pragma GCC diagnostic pop
 #include "src/cuda/query_blocksize.cuh"
 #include "src/cuda/relayout_format/relayout_format.cuh"
+#include "src/cuda/relayout_format/helper.cuh"

 using namespace megdnn;
 using namespace cuda;
@@ -728,17 +729,18 @@ struct Translayout<64, 2, SrcType, dtype::Quantized4Asymm,
 #undef pack

 template <typename DstType>
-inline __device__ DstType make_zero_pad(const char zero_point) {
+inline __device__ DstType make_zero_pad(const uint8_t zero_point) {
     return zero_point;
 }

 template <>
-inline __device__ char4 make_zero_pad<char4>(const char zero_point) {
-    return {zero_point, zero_point, zero_point, zero_point};
+inline __device__ char4 make_zero_pad<char4>(const uint8_t zero_point) {
+    char izp = reinterpret_cast<const char&>(zero_point);
+    return {izp, izp, izp, izp};
 }

 template <>
-inline __device__ int4 make_zero_pad<int4>(const char zero_point) {
+inline __device__ int4 make_zero_pad<int4>(const uint8_t zero_point) {
     return {zero_point, zero_point, zero_point, zero_point};
 }
@@ -767,7 +769,7 @@ inline __device__ void write_helper<array_wrapper<uint8_t, 32>>(
             : "l"(ptr_), "r"(data[0].x), "r"(data[0].y), "r"(data[0].z),
               "r"(data[0].w), "l"(((uint8_t*)ptr_) + 16), "r"(data[1].x),
               "r"(data[1].y), "r"(data[1].z), "r"(data[1].w));
-};
+}

 template <bool with_pad, int pack_w, int pack_c, bool same_scale, bool all_pad,
           typename SrcType, typename DstType, typename DnnSrcType,
@@ -825,7 +827,7 @@ struct RelayoutKern {
             const SrcType* src, DstType* dst, const int ic_stride,
             const int remain_ic,
             CudaPostProcess<DnnSrcType, DnnDstType, same_scale>& post_process,
-            const char zero_point) {
+            const uint8_t zero_point) {
         InnerDtype read_channel[pack_c];
         if (all_pad) {
             const InnerDtype zero_pad = make_zero_pad<InnerDtype>(zero_point);
@@ -855,7 +857,7 @@ __global__ void kern_nchw_nchwx(
         const SrcType* src, DstType* dst, int in_n, int ic, int ihw,
         int n_stride_src, int ic_stride, int n_stride_dst, int oc_stride,
         CudaPostProcess<DnnSrcType, DnnDstType, same_scale> post_process,
-        const char zero_point, const int group, const int ocpg) {
+        const uint8_t zero_point, const int group, const int ocpg) {
     static constexpr int size_src_type = sizeof(SrcType);
     static constexpr int size_dst_type = sizeof(DstType);
 #ifndef MEGDNN_COMMA
@@ -1072,6 +1074,7 @@ public:
     MEGDNN_DEVICE __forceinline__ void initialize(int c_idx, int hw_idx) {
         pointer += (c_idx / pack_size) * chan_stride_in_elements +
                    hw_idx * pack_size * size_nbits / (8 * sizeof(Type));
+        channel -= c_idx;
     }

     MEGDNN_DEVICE __forceinline__ void add_pointer_offset(
@@ -1079,7 +1082,7 @@ public:
         pointer += offset_in_type;
     }

-    MEGDNN_DEVICE __forceinline__ void load(Fragment& frag) {
+    MEGDNN_DEVICE __forceinline__ void load(Fragment& frag, int zero_point) {
         AccessType* frag_ptr = reinterpret_cast<AccessType*>(&frag);
         Type* pointer_ = pointer;
 #pragma unroll
@@ -1090,11 +1093,12 @@ public:
                                 (lane_size_in_type / pack_size_in_type) +
                         j;
                 bool guard = i < channel;
-                cutlass::arch::global_load<AccessType, pack_size_in_byte>(
+                relayout_format::global_load_with_zero_point<AccessType,
+                                                             pack_size_in_byte>(
                         frag_ptr[frag_idx],
                         reinterpret_cast<void*>(pointer_ +
                                                 j * pack_size_in_type),
-                        guard);
+                        guard, zero_point);
             }
             pointer_ += chan_stride_in_elements;
         }
@@ -1173,6 +1177,7 @@ public:
     MEGDNN_DEVICE __forceinline__ void initialize(int c_idx, int hw_idx) {
         pointer += (c_idx / pack_size) * chan_stride_in_elements;
+        channel -= c_idx;
 #pragma unroll
         for (int i = 0; i < mask_size; ++i) {
             mask[i] = 0;
@@ -1201,7 +1206,7 @@ public:
         pointer += offset_in_type;
     }

-    MEGDNN_DEVICE __forceinline__ void load(Fragment& frag) {
+    MEGDNN_DEVICE __forceinline__ void load(Fragment& frag, int zero_point) {
         AccessType* frag_ptr = reinterpret_cast<AccessType*>(&frag);
         Type* pointer_ = pointer;
 #pragma unroll
@@ -1214,9 +1219,11 @@ public:
                 int mask_index = (frag_idx >> 5);
                 int mask_shift = (frag_idx & 0x1f);
                 bool guard = (mask[mask_index] & (1 << mask_shift));
-                cutlass::arch::global_load<AccessType, pack_size_in_byte>(
+                relayout_format::global_load_with_zero_point<AccessType,
+                                                             pack_size_in_byte>(
                         frag_ptr[frag_idx],
-                        reinterpret_cast<void*>(pointer_ + stride[j]), guard);
+                        reinterpret_cast<void*>(pointer_ + stride[j]), guard,
+                        zero_point);
             }
             pointer_ += chan_stride_in_elements;
         }
@@ -1306,11 +1313,13 @@ struct RelayoutProblem {
         int batch_size;
         int channels;
         int hw;
+        int zero_point;
         MEGDNN_HOST MEGDNN_DEVICE Param(SrcIterator src_iterator_,
                                         DstIterator dst_iterator_,
                                         CudaPostProcess post_process_,
                                         int n_stride_src_, int n_stride_dst_,
-                                        int batch_size_, int channels_, int hw_)
+                                        int batch_size_, int channels_, int hw_,
+                                        int zero_point_)
                 : src_iterator{src_iterator_},
                   dst_iterator{dst_iterator_},
                   post_process{post_process_},
@@ -1318,7 +1327,8 @@ struct RelayoutProblem {
                   n_stride_dst{n_stride_dst_},
                   batch_size{batch_size_},
                   channels{channels_},
-                  hw{hw_} {}
+                  hw{hw_},
+                  zero_point{zero_point_} {}
     };
 };
@@ -1345,7 +1355,9 @@ __global__ void relayout_kern(typename RelayoutProblem_::Param param) {
         param.dst_iterator.initialize(c_idx, hw_idx);
         typename SrcIterator::Fragment src_frag;
         typename DstIterator::Fragment dst_frag;
-        param.src_iterator.load(src_frag);
+        int zp = relayout_format::make_zero<SrcIterator::size_nbits>(
+                param.zero_point);
+        param.src_iterator.load(src_frag, zp);
         RelayoutProblem_::Transpose::trans(
                 reinterpret_cast<typename SrcIterator::Fragment&>(dst_frag),
                 src_frag, param.post_process);
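// For 4-bit dtypes, make_zero<4> broadcasts the zero point into every nibble
// of the 32-bit word handed to the iterator, so padded (out-of-bounds) lanes
// decode back to the dtype's zero point; e.g. a Quantized4Asymm zero point of
// 8 yields zp = 0x88888888.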
@@ -1382,7 +1394,8 @@ void relayout_format::relayout_format_cuda_nchw_nchwx(
             stype.name(), dtype.name());
 #undef DEF
     // no padding
-    if (src.layout.stride[2] == static_cast<ptrdiff_t>(src.layout[3])) {
+    if (stype.enumv().ev != DTypeEnum::Ev::QuantizedS4 &&
+        stype.enumv().ev != DTypeEnum::Ev::Quantized4Asymm) {
         const int in_n = src.layout[0];
         const int out_n = dst.layout[0];
         const int ic = src.layout[1];
@@ -1428,18 +1441,10 @@ void relayout_format::relayout_format_cuda_nchw_nchwx(
     DISPATCH_RAW(false, 4, 4, _src_type, _dst_type, char, char, 8); \
     DISPATCH_RAW(true, 1, 4, _src_type, _dst_type, char, char, 8);  \
     DISPATCH_RAW(false, 1, 4, _src_type, _dst_type, char, char, 8);
-#define DISPATCH_4BITS(_src_type, _dst_type)                         \
-    DISPATCH_RAW(true, 8, 64, _src_type, _dst_type, char, char, 4);  \
-    DISPATCH_RAW(false, 8, 64, _src_type, _dst_type, char, char, 4); \
-    DISPATCH_RAW(true, 2, 64, _src_type, _dst_type, char, char, 4);  \
-    DISPATCH_RAW(false, 2, 64, _src_type, _dst_type, char, char, 4);
         DISPATCH_INT(QuantizedS32, QuantizedS32);
         DISPATCH_BYTE(Uint8, QuantizedS8);
         DISPATCH_BYTE(Quantized8Asymm, QuantizedS8);
         DISPATCH_BYTE(QuantizedS8, QuantizedS8);
-        DISPATCH_4BITS(QuantizedS4, QuantizedS4);
-        DISPATCH_4BITS(Quantized4Asymm, Quantized4Asymm);
-#undef DISPATCH_4BITS
 #undef DISPATCH_BYTE
 #undef DISPATCH_INT
 #undef DISPATCH_RAW
@@ -1450,7 +1455,8 @@ void relayout_format::relayout_format_cuda_nchw_nchwx(
     } else {
         megdnn_assert(src_layout.dtype.is_low_bit());
         int n = src.layout[0];
-        int c = src.layout[1];
+        int ic = src.layout[1];
+        int oc = dst.layout[1] * 64;
         int h = src.layout[2];
         // align to byte
         int w = src.layout[3];
@@ -1460,12 +1466,13 @@ void relayout_format::relayout_format_cuda_nchw_nchwx(
         int ic_stride = src_layout.stride[1];
         int n_stride_dst = dst_layout.stride[0];
         int oc_stride = dst_layout.stride[1];
-        int problem_size = n * (c / pack_oc) * hw;
+        int problem_size = n * (oc / pack_oc) * hw;
         bool same_scale = src_scale == dst_scale;
-#define DISPATCH_RAW(_same_scale, _pack_w, _pack_oc, _src_type, _dst_type,   \
-                     _src_c_type, _dst_c_type, _size_nbits)                  \
-    if (same_scale == _same_scale && hw % _pack_w == 0 &&                    \
-        stype.enumv().ev == DTypeEnum::Ev::_src_type &&                      \
+        bool padding = w % 2 != 0;
+#define DISPATCH_RAW(_padding, _same_scale, _pack_w, _pack_oc, _src_type,    \
+                     _dst_type, _src_c_type, _dst_c_type, _size_nbits)       \
+    if (padding == _padding && same_scale == _same_scale &&                  \
+        hw % _pack_w == 0 && stype.enumv().ev == DTypeEnum::Ev::_src_type && \
         dtype.enumv().ev == DTypeEnum::Ev::_dst_type) {                      \
         using InnerDtype_ = typename DTypeRWHelper<                          \
                 typename DTypeTrait<dtype::_src_type>::ctype,                \
@@ -1473,8 +1480,10 @@ void relayout_format::relayout_format_cuda_nchw_nchwx(
         using SrcIterator_ =                                                  \
                 TensorIteratorOverChannel<InnerDtype_, 1, _pack_oc, _pack_w,  \
                                           _size_nbits>;                      \
-        using DstIterator_ = MaskedTensorIteratorOverChannel<                 \
-                _dst_c_type, _pack_oc, _pack_oc, _pack_w, _size_nbits>;       \
+        using DstIterator_ =                                                  \
+                typename TensorIteratorPolicy<_padding, _dst_c_type, _pack_oc, \
+                                              _pack_oc, _pack_w,              \
+                                              _size_nbits>::TensorIterator;   \
         using CudaPostProcess_ =                                              \
                 CudaPostProcess<dtype::_src_type, dtype::_dst_type,           \
                                 _same_scale>;                                 \
@@ -1489,17 +1498,18 @@ void relayout_format::relayout_format_cuda_nchw_nchwx(
     n_stride_dst = n_stride_dst * _size_nbits / (8 * sizeof(_dst_c_type));    \
     oc_stride = oc_stride * _size_nbits / (8 * sizeof(_dst_c_type));          \
     typename RelayoutProblem_::Param param{                                   \
-            SrcIterator_{(InnerDtype_*)src.raw_ptr, ic_stride, c, w,          \
+            SrcIterator_{(InnerDtype_*)src.raw_ptr, ic_stride, ic, w,         \
                          w_pad},                                              \
-            DstIterator_{(_dst_c_type*)dst.raw_ptr, oc_stride, c, w,          \
+            DstIterator_{(_dst_c_type*)dst.raw_ptr, oc_stride, oc, w,         \
                          w_pad},                                              \
             CudaPostProcess_{src_scale, src_zero_point, dst_scale,            \
                              dst_zero_point},                                 \
             n_stride_src,                                                     \
             n_stride_dst,                                                     \
             n,                                                                \
-            c,                                                                \
-            hw};                                                              \
+            oc,                                                               \
+            hw,                                                               \
+            src_zero_point};                                                  \
     auto kernel = relayout_kern<RelayoutProblem_>;                            \
     int nr_threads = query_blocksize_for_kernel(kernel);                      \
     nr_threads = std::min(nr_threads, DIVUP(problem_size, _pack_w));          \
@@ -1507,11 +1517,15 @@ void relayout_format::relayout_format_cuda_nchw_nchwx(
     const dim3 thread_dim(nr_threads);                                        \
     return kernel<<<block_dim, thread_dim, 0, stream>>>(param);               \
     }
-#define DISPATCH_4BITS(_src_type, _dst_type)                         \
-    DISPATCH_RAW(true, 8, 64, _src_type, _dst_type, char, char, 4);  \
-    DISPATCH_RAW(false, 8, 64, _src_type, _dst_type, char, char, 4); \
-    DISPATCH_RAW(true, 2, 64, _src_type, _dst_type, char, char, 4);  \
-    DISPATCH_RAW(false, 2, 64, _src_type, _dst_type, char, char, 4);
+#define DISPATCH_4BITS(_src_type, _dst_type)                                \
+    DISPATCH_RAW(true, true, 8, 64, _src_type, _dst_type, char, char, 4);   \
+    DISPATCH_RAW(true, false, 8, 64, _src_type, _dst_type, char, char, 4);  \
+    DISPATCH_RAW(true, true, 2, 64, _src_type, _dst_type, char, char, 4);   \
+    DISPATCH_RAW(true, false, 2, 64, _src_type, _dst_type, char, char, 4);  \
+    DISPATCH_RAW(false, true, 8, 64, _src_type, _dst_type, char, char, 4);  \
+    DISPATCH_RAW(false, false, 8, 64, _src_type, _dst_type, char, char, 4); \
+    DISPATCH_RAW(false, true, 2, 64, _src_type, _dst_type, char, char, 4);  \
+    DISPATCH_RAW(false, false, 2, 64, _src_type, _dst_type, char, char, 4);
         DISPATCH_4BITS(QuantizedS4, QuantizedS4);
         DISPATCH_4BITS(Quantized4Asymm, Quantized4Asymm);
 #undef DISPATCH_4BITS
@@ -1521,7 +1535,6 @@ void relayout_format::relayout_format_cuda_nchw_nchwx(
             "Unsupported data type(src:%s, dst:%s) or image size(%dx%d).",
             stype.name(), dtype.name(), h, w);
     }
-    after_kernel_launch();
 }
 bool relayout_format::relayout_format_cuda_usable(
@@ -1568,7 +1581,7 @@ void relayout_format::relayout_format_cuda_nchwx_nchw(
     megdnn_assert(pack_ic == 64, "Unsupported pack size(pack_ic:%d)", pack_ic);
 #undef DEF
     int n = src.layout[0];
-    int c = src.layout[1] * pack_ic;
+    int ic = src.layout[1] * pack_ic;
     int h = src.layout[2];
     // align to byte
     int w = src.layout[3];
@@ -1578,7 +1591,8 @@ void relayout_format::relayout_format_cuda_nchwx_nchw(
     int ic_stride = src_layout.stride[1];
     int n_stride_dst = dst_layout.stride[0];
     int oc_stride = dst_layout.stride[1];
-    int problem_size = n * (c / pack_ic) * hw;
+    int problem_size = n * (ic / pack_ic) * hw;
+    int oc = dst.layout[1];
     bool same_scale = src_scale == dst_scale;
     bool padding = w % 2 != 0;
@@ -1611,17 +1625,18 @@ void relayout_format::relayout_format_cuda_nchwx_nchw(
     n_stride_dst = n_stride_dst * _size_nbits / (8 * sizeof(InnerDtype_));    \
     oc_stride = oc_stride * _size_nbits / (8 * sizeof(InnerDtype_));          \
     typename RelayoutProblem_::Param param{                                   \
-            SrcIterator_{(_src_c_type*)src.raw_ptr, ic_stride, c, w,          \
+            SrcIterator_{(_src_c_type*)src.raw_ptr, ic_stride, ic, w,         \
                          w_pad},                                              \
-            DstIterator_{(InnerDtype_*)dst.raw_ptr, oc_stride, c, w,          \
+            DstIterator_{(InnerDtype_*)dst.raw_ptr, oc_stride, oc, w,         \
                          w_pad},                                              \
             CudaPostProcess_{src_scale, src_zero_point, dst_scale,            \
                              dst_zero_point},                                 \
             n_stride_src,                                                     \
             n_stride_dst,                                                     \
             n,                                                                \
-            c,                                                                \
-            hw};                                                              \
+            ic,                                                               \
+            hw,                                                               \
+            src_zero_point};                                                  \
     auto kernel = relayout_kern<RelayoutProblem_>;                            \
     int nr_threads = query_blocksize_for_kernel(kernel);                      \
     nr_threads = std::min(nr_threads, DIVUP(problem_size, _pack_w));          \
@@ -1645,7 +1660,6 @@ void relayout_format::relayout_format_cuda_nchwx_nchw(
     megdnn_assert(false,
                   "Unsupported data type(src:%s, dst:%s) or image size(%dx%d).",
                   stype.name(), dtype.name(), h, w);
-    after_kernel_launch();
 }

 void relayout_format::relayout_format_cuda_nchw4_nchw(
@@ -21,6 +21,7 @@
 #include "cuda.h"
 #include "src/cuda/cudnn_with_check.h"
 #include "cutlass/cutlass.h"
+#include "cutlass/platform/platform.h"

 #define cuda_check(_x)           \
     do {                         \
@@ -448,13 +449,12 @@ MEGDNN_DEVICE __forceinline__ static int transform_int8_to_uint4x8(
 template <bool signedness, typename T>
 MEGDNN_DEVICE __forceinline__ static int unpack_integer_4bits(T storage,
                                                               int bits) {
-    uint8_t result = (uint8_t)((storage >> bits) & 0xf);
-    if (signedness) {
-        static constexpr uint8_t mask = (uint8_t)((1 << 4) - 1);
-        return (result & uint8_t(1 << 3)) ? ((int)(result) | ~(int)(mask))
-                                          : (int)(result);
-    }
-    return int(result);
+    static constexpr int shift = 28;
+    using type = typename cutlass::platform::conditional<signedness, int,
+                                                         unsigned>::type;
+    unsigned intermediate = static_cast<unsigned>(storage);
+    type result = reinterpret_cast<type&>(intermediate);
+    return (result << (shift - bits)) >> shift;
 }
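// Worked example of the shift trick above: with shift = 28, the nibble that
// starts at bit `bits` is moved into bits [31:28] by `<< (shift - bits)` and
// brought back down by `>> shift`; the right shift is arithmetic for int
// (sign-extends, QuantizedS4) and logical for unsigned (zero-extends,
// Quantized4Asymm). For storage = 0xA0 and bits = 4:
//     signed:   (int)0x000000A0 << 24 == 0xA0000000, >> 28 == -6
//     unsigned: 0x000000A0u << 24 == 0xA0000000u, >> 28 == 10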
 MEGDNN_DEVICE __forceinline__ static void transform_int4x8_to_int8(
@@ -42,6 +42,36 @@ void recursive_cp(const TensorND& dst, const TensorND& src, size_t idx = 0,
     }
 }

+template <size_t size_nbits>
+void lowbit_recursive_cp(const TensorND& dst, const TensorND& src,
+                         size_t idx = 0, size_t src_offset = 0,
+                         size_t dst_offset = 0) {
+    MEGDNN_STATIC_ASSERT(!(8_z % size_nbits),
+                         "size in bits of lowbit data type can only be 1, 2, 4 "
+                         "or 8");
+    if (idx < (src.layout.ndim - 1)) {
+        for (size_t i = 0; i < src.layout[idx]; ++i) {
+            lowbit_recursive_cp<size_nbits>(
+                    dst, src, idx + 1, src_offset + i * src.layout.stride[idx],
+                    dst_offset + i * dst.layout.stride[idx]);
+        }
+    } else {
+        megdnn_assert(src.layout.stride[idx] == 1);
+        megdnn_assert(dst.layout.stride[idx] == 1);
+        size_t dim_bytes = div_ceil(src.layout[idx], 8_z / size_nbits);
+        // offset in elements
+        uint8_t* dptr = reinterpret_cast<uint8_t*>(dst.raw_ptr) +
+                        (dst_offset * size_nbits / 8);
+        uint8_t* sptr = reinterpret_cast<uint8_t*>(src.raw_ptr) +
+                        (src_offset * size_nbits / 8);
+        for (size_t i = 0; i < dim_bytes; ++i) {
+            *dptr = *sptr;
+            dptr++;
+            sptr++;
+        }
+    }
+}
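// Byte math of the innermost copy above: with 4-bit data two elements share a
// byte, so a contiguous row of src.layout[idx] = 7 elements copies
// div_ceil(7, 8 / 4) = 4 bytes, and element offsets are scaled by
// size_nbits / 8 when converted to byte addresses.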
 void padding_to_workspace(_megdnn_tensor_out dst, _megdnn_tensor_in src) {
     switch (src.layout.dtype.enumv()) {
 #define cb(name, ctype) \
@@ -54,10 +84,17 @@ void padding_to_workspace(_megdnn_tensor_out dst, _megdnn_tensor_in src) {
         cb(Int32, dt_int32);
         cb(QuantizedS32, dt_int32);
         cb(QuantizedS8, dt_qint8);
+#undef cb
+#define cb(name, size_nbits)                       \
+    case (DTypeEnum::name): {                      \
+        lowbit_recursive_cp<size_nbits>(dst, src); \
+        break;                                     \
+    }
+        cb(QuantizedS4, 4);
+        cb(Quantized4Asymm, 4);
+#undef cb
         default:
             megdnn_assert(0, "unsupported dtype %s", src.layout.dtype.name());
-#undef cb
     }
 }
@@ -66,24 +103,27 @@ void extract_from_workspace(_megdnn_tensor_out dst, _megdnn_tensor_in src,
     megdnn_assert(dst.layout.is_contiguous() && src.layout.is_contiguous(),
                   "dst %s, src %s", dst.layout.to_string().c_str(),
                   src.layout.to_string().c_str());
-    const size_t type_size = dst.layout.dtype.size();
     const size_t n = dst.layout[0];
-    const size_t n_stride_dst = dst.layout.stride[0];
-    const size_t n_stride_src = src.layout.stride[0];
+    const size_t n_stride_dst_in_bytes =
+            dst.layout.dtype.size(dst.layout.stride[0]);
+    const size_t n_stride_src_in_bytes =
+            src.layout.dtype.size(src.layout.stride[0]);
     const size_t ocpg = dst.layout[1] / group;
     const size_t icpg = src.layout[1] / group;
-    const size_t dst_hw = dst.layout[2] * dst.layout[3];
-    const size_t src_hw = src.layout[2] * src.layout[3];
-    megdnn_assert(dst_hw == src_hw);
+    const size_t dst_c_stride_in_bytes =
+            dst.layout.dtype.size(dst.layout.stride[1]);
+    const size_t src_c_stride_in_bytes =
+            src.layout.dtype.size(src.layout.stride[1]);
+    megdnn_assert(dst_c_stride_in_bytes == src_c_stride_in_bytes);
     for (size_t nid = 0; nid < n; ++nid) {
-        const size_t n_offset_dst = nid * n_stride_dst * type_size;
-        const size_t n_offset_src = nid * n_stride_src * type_size;
+        const size_t n_offset_dst = nid * n_stride_dst_in_bytes;
+        const size_t n_offset_src = nid * n_stride_src_in_bytes;
         for (size_t gid = 0; gid < group; ++gid) {
             memcpy((char*)dst.raw_ptr + n_offset_dst +
-                           gid * ocpg * dst_hw * type_size,
+                           gid * ocpg * dst_c_stride_in_bytes,
                    (char*)src.raw_ptr + n_offset_src +
-                           gid * icpg * src_hw * type_size,
-                   ocpg * dst_hw * type_size);
+                           gid * icpg * src_c_stride_in_bytes,
+                   ocpg * dst_c_stride_in_bytes);
         }
     }
 };
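// Note on the stride-to-byte conversion above (illustrative numbers):
// dtype.size(n) rounds n elements to whole bytes, so for Quantized4Asymm a
// channel stride of 56 * 56 = 3136 elements becomes 3136 / 2 = 1568 bytes,
// and each group-wise memcpy then moves ocpg * 1568 bytes.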
@@ -415,6 +455,30 @@ size_t RelayoutFormatImpl::get_workspace_in_bytes(const TensorLayout& src,
             return oc * ic * h * w * src.dtype.size();
         }
+        case Param::Mode::NCHW_NCHW64: {
+            if (src[1] % 64 != 0) {
+                size_t n = src[0];
+                size_t c = round_up(src[1], 64_z);
+                size_t h = src[2];
+                size_t w = src[3];
+                TensorLayout wsly({n, c, h, w}, src.dtype);
+                return wsly.span().dist_byte();
+            }
+            return 0_z;
+        }
+        case Param::Mode::NCHW64_NCHW: {
+            if (param().oc != 0) {
+                size_t n = src[0];
+                size_t c = src[1] * 64;
+                size_t h = src[2];
+                size_t w = src[3];
+                TensorLayout wsly({n, c, h, w}, dst.dtype);
+                return wsly.span().dist_byte();
+            }
+            return 0_z;
+        }
         default:
             return 0;
     }
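// Worked example for the NCHW_NCHW64 workspace above (assumed shapes): for
// src = {1, 15, 7, 7} with QuantizedS4, c rounds up to 64, and the workspace
// layout {1, 64, 7, 7} spans 1 * 64 * 7 * 7 / 2 = 1568 bytes.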
@@ -437,6 +501,7 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
     // clean dst
     MEGDNN_DISPATCH_CPU_KERN(
             m_handle, memset(dst.raw_ptr, 0, dst.layout.span().dist_byte()));
+    // pre
     if (param().mode == Param::Mode::NCHW_NHWCD4I) {
         size_t N = src.layout[0];
         size_t IC = src.layout[1];
@@ -551,6 +616,27 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
             cb2(2, 4, NCHW_NCHW4, group_src_layout, workspace_layout);
         }
+    } else if (param().mode == Param::Mode::NCHW_NCHW64) {
+        MIDOUT_BEGIN(megdnn_naive_relayout_format,
+                     midout_iv(Param::Mode::NCHW_NCHW64)) {
+            size_t c = src.layout[1];
+            if (c % 64 != 0) {
+                uint8_t zp = 0;
+                if (src.layout.dtype.enumv() == DTypeEnum::Quantized4Asymm) {
+                    zp = src.layout.dtype.param<dtype::Quantized4Asymm>()
+                                 .zero_point;
+                    zp = (zp & 0xf) | (zp << 4);
+                }
+                MEGDNN_DISPATCH_CPU_KERN(
+                        m_handle, memset(workspace.raw_ptr, zp,
+                                         exec_workspace.span().dist_byte()));
+                TensorND ws_nd(workspace.raw_ptr, exec_workspace);
+                MEGDNN_DISPATCH_CPU_KERN(m_handle,
+                                         padding_to_workspace(ws_nd, src););
+                exec_src_nd.raw_ptr = workspace.raw_ptr;
+            }
+        }
+        MIDOUT_END();
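// The nibble duplication above replicates the 4-bit zero point into both
// halves of a byte so that a plain memset fills every packed u4 element with
// it; e.g. zp = 4 -> (4 & 0xf) | (4 << 4) = 0x44.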
     } else if (param().mode ==
                Param::Mode::NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT) {
         cb(1, 4, NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT);
@@ -574,24 +660,16 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
             cb(1, 2, 4, NCHW_NCHW4_WEIGHT);
         }
     } else if (param().mode == Param::Mode::NCHW4_NCHW) {
-        if (exec_workspace.total_nr_elems() == dst.layout.total_nr_elems()) {
-            m_handle->relayout_opr()->exec(
-                    exec_src_nd, {dst.raw_ptr, exec_workspace}, handle());
-            return;
-        } else {
-            m_handle->relayout_opr()->exec(
-                    exec_src_nd, {workspace.raw_ptr, exec_workspace}, handle());
-            TensorLayout workspace_layout{{src.layout[0], src.layout[1] * 4,
-                                           src.layout[2], src.layout[3]},
-                                          src.layout.dtype,
-                                          src.layout.format};
-            extract_from_workspace(exec_dst_nd,
-                                   {workspace.raw_ptr, workspace_layout},
-                                   param().group);
-            return;
+        if (exec_workspace.total_nr_elems() != dst.layout.total_nr_elems()) {
+            exec_dst_nd = {workspace.raw_ptr, exec_workspace};
         }
+    } else if (param().mode == Param::Mode::NCHW64_NCHW) {
+        if (exec_workspace.total_nr_elems() != dst.layout.total_nr_elems()) {
+            exec_dst_nd = {workspace.raw_ptr, exec_workspace};
+        }
     }
+    // do relayout
     if (src.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm &&
         dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8) {
         TensorND src0 = exec_src_nd, dst0 = exec_dst_nd;
@@ -600,7 +678,6 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
             do_copy_diff_qu8_q8(dst, src);
         };
         MEGDNN_DISPATCH_CPU_KERN_OPR(func(dst0, src0));
-        return;
     } else if (src.layout.dtype.enumv() == DTypeEnum::Uint8 &&
                dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8) {
         TensorND src0 = exec_src_nd, dst0 = exec_dst_nd;
@@ -609,7 +686,6 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
             do_copy_diff_u8_q8(dst, src);
         };
         MEGDNN_DISPATCH_CPU_KERN_OPR(func(dst0, src0));
-        return;
     } else if (src.layout.dtype.enumv() == DTypeEnum::QuantizedS8 &&
                dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8) {
         TensorND src0 = exec_src_nd, dst0 = exec_dst_nd;
@@ -618,7 +694,6 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
             do_copy_diff_q8_q8(dst, src);
         };
         MEGDNN_DISPATCH_CPU_KERN_OPR(func(dst0, src0));
-        return;
     } else if (src.layout.dtype.enumv() == DTypeEnum::QuantizedS32 &&
                dst.layout.dtype.enumv() == DTypeEnum::QuantizedS32) {
         TensorND src0 = exec_src_nd, dst0 = exec_dst_nd;
@@ -627,7 +702,6 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
             do_copy_diff_q32_q32(dst, src);
         };
         MEGDNN_DISPATCH_CPU_KERN_OPR(func(dst0, src0));
-        return;
     } else if (src.layout.dtype.enumv() == DTypeEnum::QuantizedS4 &&
               dst.layout.dtype.enumv() == DTypeEnum::QuantizedS4) {
         TensorND src0 = exec_src_nd, dst0 = exec_dst_nd;
@@ -636,7 +710,6 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
             do_copy_diff_q4_q4(dst, src);
         };
         MEGDNN_DISPATCH_CPU_KERN_OPR(func(dst0, src0));
-        return;
     } else if (src.layout.dtype.enumv() == DTypeEnum::Quantized4Asymm &&
                dst.layout.dtype.enumv() == DTypeEnum::Quantized4Asymm) {
         TensorND src0 = exec_src_nd, dst0 = exec_dst_nd;
@@ -645,9 +718,20 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
             do_copy_diff_qu4_qu4(dst, src);
         };
         MEGDNN_DISPATCH_CPU_KERN_OPR(func(dst0, src0));
-        return;
     } else {
         m_handle->relayout_opr()->exec(exec_src_nd, exec_dst_nd, handle());
     }
+    // post
+    if (param().mode == Param::Mode::NCHW4_NCHW ||
+        param().mode == Param::Mode::NCHW64_NCHW) {
+        if (exec_workspace.total_nr_elems() != dst.layout.total_nr_elems()) {
+            megdnn_assert(exec_workspace.dtype == dst.layout.dtype);
+            TensorND ws_nd{workspace.raw_ptr, exec_workspace};
+            MEGDNN_DISPATCH_CPU_KERN(
+                    m_handle,
+                    extract_from_workspace(dst, ws_nd, param().group););
+        }
+    }
 #undef cb
 }

@@ -18,7 +18,6 @@
 using namespace megdnn;
 using namespace test;
-#define MEGDNN_WITH_BENCHMARK 1

 TEST_F(CUDA, RELAYOUT_FORMAT) {
     Checker<RelayoutFormat> checker(handle_cuda());
@@ -245,7 +244,7 @@ TEST_F(CUDA, RELAYOUT_FORMAT_NCHW_NCHW64) {
     param::RelayoutFormat param;
     param.mode = param::RelayoutFormat::Mode::NCHW_NCHW64;
     for (size_t n : {1, 3}) {
-        for (size_t c : {64, 128}) {
+        for (size_t c : {15, 64, 128}) {
             for (size_t h : {7, 14, 16, 28}) {
                 for (size_t w : {2, 3, 7, 8, 16, 31}) {
                     checker.set_dtype(0, dtype::QuantizedS4{2.f})
@@ -285,36 +284,41 @@ TEST_F(CUDA, RELAYOUT_FORMAT_NCHW64_NCHW) {
     param::RelayoutFormat param;
     param.mode = param::RelayoutFormat::Mode::NCHW64_NCHW;
     for (size_t n : {1, 3}) {
-        for (size_t c : {64, 128}) {
+        for (size_t c : {15, 64, 128}) {
             for (size_t h : {7, 14, 16, 28}) {
                 for (size_t w : {2, 3, 4, 7, 14, 16, 17}) {
+                    if (c % 64 != 0) {
+                        param.oc = c;
+                    } else {
+                        param.oc = 0;
+                    }
                     checker.set_dtype(0, dtype::QuantizedS4{2.f})
                             .set_dtype(1, dtype::QuantizedS4{2.f})
                             .set_rng(0, &s4)
                             .set_param(param)
                             .set_epsilon(1e-3)
-                            .execs({{n, c / 64, h, w, 64}, {}});
+                            .execs({{n, (c + 63) / 64, h, w, 64}, {}});
                     checker.set_dtype(0, dtype::Quantized4Asymm{1.2f, 4})
                             .set_dtype(1, dtype::Quantized4Asymm{1.2f, 8})
                             .set_rng(0, &u4)
                             .set_param(param)
                             .set_epsilon(1e-3)
-                            .execs({{n, c / 64, h, w, 64}, {}});
+                            .execs({{n, (c + 63) / 64, h, w, 64}, {}});
                     checker.set_dtype(0, dtype::QuantizedS4{1.19990307f})
                             .set_dtype(1, dtype::QuantizedS4{1.f})
                             .set_rng(0, &s4)
                             .set_param(param)
                             .set_epsilon(1e-3)
-                            .execs({{n, c / 64, h, w, 64}, {}});
+                            .execs({{n, (c + 63) / 64, h, w, 64}, {}});
                     checker.set_dtype(0, dtype::Quantized4Asymm{1.20211209f, 8})
                             .set_dtype(1, dtype::Quantized4Asymm{1.f, 4})
                             .set_rng(0, &u4)
                             .set_param(param)
                             .set_epsilon(1e-3)
-                            .execs({{n, c / 64, h, w, 64}, {}});
+                            .execs({{n, (c + 63) / 64, h, w, 64}, {}});
                 }
             }
         }
@@ -375,10 +379,14 @@ TEST_F(CUDA, BENCHMARK_RELAYOUT_FORMAT_QS4) {
         CUBenchmarker<RelayoutFormat> benchmarker(handle_cuda());
         benchmarker.set_param(param);
         benchmarker.set_dtype(0, dtype::QuantizedS4{1.19990307f})
-                .set_dtype(1, dtype::QuantizedS4{1.20210322f});
+                .set_dtype(1, dtype::QuantizedS4{1.19990307f});
         for (auto&& shape : shapes) {
-            double memaccess = double(shape.total_nr_elems()) * 1e-6;
+            double memaccess =
+                    double(TensorLayout(shape, dtype::QuantizedS4{1.f})
+                                   .span()
+                                   .dist_byte()) *
+                    2e-6;
             auto time_ms = benchmarker.execs({shape, {}});
             printf("execute %s, time %.4f ms, %.4f GB/s\n",
                    shape.to_string().c_str(), time_ms, memaccess / time_ms);
@@ -387,8 +395,9 @@ TEST_F(CUDA, BENCHMARK_RELAYOUT_FORMAT_QS4) {
     {
         TensorShapeArray shapes = {
-                {1, 64, 56, 56},  {16, 64, 56, 56}, {64, 64, 56, 56},
-                {1, 64, 56, 55},  {16, 64, 56, 55}, {64, 64, 56, 55},
+                {1, 64, 56, 56},  {16, 64, 56, 56}, {64, 64, 56, 56},
+                {1, 64, 56, 55},  {16, 64, 56, 55}, {64, 64, 56, 55},
+                {1, 256, 384, 640},
         };
         Param param;
         param.mode = param::RelayoutFormat::Mode::NCHW_NCHW64;
@@ -399,7 +408,8 @@ TEST_F(CUDA, BENCHMARK_RELAYOUT_FORMAT_QS4) {
                 {64, 1, 56, 56, 64},
                 {1, 32, 7, 7, 64},
                 {16, 32, 7, 7, 64},
-                {64, 32, 7, 7, 64},
+                {64, 32, 7, 7, 64},
+                {1, 4, 384, 640, 64},
         };
         Param param;
         param.mode = param::RelayoutFormat::Mode::NCHW64_NCHW;