GitOrigin-RevId: 07f2ee6c5b
release-0.5
@@ -434,7 +434,7 @@ pdef('PowC', 'power with constant exponent').add_fields('float32', 'exp', 0)
'layout is (K/4, M/4, 4(k), 4(m)) x (K/4, N, 4(k))'),
Doc('MK8', 'Split 8 from M and K, better for neon compute:'
'(M/8, K/8, 8(k), 8(m)) x (K/8, N, 8(k)). if transposeA the '
'layout is (K/8, M/8, 8(k), 8(m)) x (K/8, N, 8(k))'),
'layout is (K/8, M/8, 8(k), 8(m)) x (K/8, N, 8(k))'),
Doc('MK4_DOT', 'Split 4 from M and K, better for neon dotprod:'
'(M/4, K/4, 4(m), 4(k)) x (K/4, N, 4(k)). if transposeA the '
'layout is (K/4, M/4, 4(m), 4(k)) x (K/4, N, 4(k))'))
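Note: as a reading aid for the packed layouts documented above, here is a minimal, hypothetical sketch (not part of this diff) of where element A(m, k) lands in the MK4 layout (K/4, M/4, 4(k), 4(m)), assuming M and K are multiples of 4.

```cpp
#include <cstddef>

// Offset of A(m, k) inside the MK4-packed buffer described above:
// block index (k / 4, m / 4), then 4(k) as the slower intra-block axis
// and 4(m) as the fastest one. Assumes M % 4 == 0 and K % 4 == 0.
inline std::size_t mk4_offset(std::size_t m, std::size_t k, std::size_t M) {
    return ((k / 4) * (M / 4) + m / 4) * 16 + (k % 4) * 4 + (m % 4);
}
```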
@@ -858,7 +858,10 @@ when the ``I`` suffix is present.
'NCHW_NCHW88_CONV_CHAN_WEIGHT',
'NCHW_NCHW88_CONV_GROUP_WEIGHT',
'NCHW_NCHW88',
'NCHW88_NCHW')
'NCHW88_NCHW',
'NCHW_NCHW4_IC_SMALL',
'NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT',
)
)
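Note: a hedged usage sketch for the two new modes registered above. The `param()` accessor is the usual megdnn one; how the operator instance is created is elided here.

```cpp
#include "megdnn/oprs.h"

// Select the new NCHW -> NCHW4 small-input-channel mode on an existing
// RelayoutFormat operator (operator creation from a handle is elided).
void use_ic_small_mode(megdnn::RelayoutFormat* opr) {
    opr->param().mode =
            megdnn::param::RelayoutFormat::Mode::NCHW_NCHW4_IC_SMALL;
}
```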
@@ -28,6 +28,26 @@ void RelayoutFormat::deduce_layout_fwd(const TensorLayout& src,
dst[3] = src[3];
dst[4] = 4;
break;
case Param::Mode::NCHW_NCHW4_IC_SMALL:
dst.ndim = 5;
megdnn_assert(src[1] <= 4_z, "ic should be less than or equal to 4");
dst[0] = src[0];
dst[1] = div_ceil(src[1], 4_z);
dst[2] = src[2];
dst[3] = src[3];
dst[4] = 4;
break;
case Param::Mode::NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT:
megdnn_assert(src.ndim == 4, "src must be oihw, ndim == 4");
megdnn_assert(src[1] <= 4_z, "ic should be less than or equal to 4");
dst.ndim = 5;
dst[0] = src[0];
dst[1] = div_ceil(src[1], 4_z);
dst[2] = src[2];
dst[3] = src[3];
dst[4] = 4;
break;
case Param::Mode::NCHW_NCHW88:
dst.ndim = 5;
dst[0] = src[0];
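Note: a minimal sketch of the shape rule the two new cases above implement, using plain arrays instead of TensorLayout (the helper name is illustrative):

```cpp
#include <array>
#include <cstddef>

// NCHW (or OIHW) with channels <= 4 maps to {N, ceil(C / 4), H, W, 4};
// for C in {1, 2, 3, 4} the second dim is always 1.
std::array<std::size_t, 5> deduce_nchw4_ic_small(
        const std::array<std::size_t, 4>& src) {
    return {src[0], (src[1] + 3) / 4, src[2], src[3], 4};
}
// e.g. {2, 3, 35, 36} -> {2, 1, 35, 36, 4}; the unused channel slots are
// zero-filled by the cuda exec path later in this diff.
```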
@@ -276,6 +296,8 @@ void RelayoutFormat::deduce_format(TensorFormat src, TensorFormat& dst) {
case Param::Mode::NCHW_NCHW88_CONV_DENSE_WEIGHT:
case Param::Mode::NCHW_NCHW88_CONV_CHAN_WEIGHT:
case Param::Mode::NCHW_NCHW88_CONV_GROUP_WEIGHT:
case Param::Mode::NCHW_NCHW4_IC_SMALL:
case Param::Mode::NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT:
CHECK_SRC(DefaultTensorFormat::make());
dst = src;
break;
@@ -374,6 +396,23 @@ void RelayoutFormat::deduce_exec_layout(const TensorLayout& src,
exec_dst = dst;
}
break;
case Param::Mode::NCHW_NCHW4_IC_SMALL:
case Param::Mode::NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT:
// nchw to nchw4c or oihw to oihw4i
{
TensorLayout work_space_layout(
{src[0], round_up(src[1], 4_z), src[2], src[3]},
src.dtype, src.format);
exec_src = work_space_layout
.reshape({src[0], div_ceil(src[1], 4_z), 4,
src[2], src[3]})
.dimshuffle({0, 1, 3, 4, 2});
exec_dst = dst;
}
break;
case Param::Mode::NCHW_NHWCD4:
case Param::Mode::NCHW_NHWCD4I:
// src is {N, C, H, W}
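Note: the exec_src built above views the (channel-padded) NCHW buffer as (N, C/4, 4, H, W) and dimshuffles it to (N, C/4, H, W, 4). A stand-alone, hypothetical version of the same data movement on a contiguous float buffer:

```cpp
#include <cstddef>
#include <vector>

// src is contiguous NCHW with C already rounded up to a multiple of 4;
// the result is the corresponding contiguous NCHW4 (N, C/4, H, W, 4) buffer.
std::vector<float> relayout_nchw_to_nchw4(const std::vector<float>& src,
                                          std::size_t N, std::size_t C,
                                          std::size_t H, std::size_t W) {
    std::vector<float> dst(N * C * H * W);
    const std::size_t CB = C / 4;  // number of 4-channel blocks
    for (std::size_t n = 0; n < N; ++n)
        for (std::size_t cb = 0; cb < CB; ++cb)
            for (std::size_t h = 0; h < H; ++h)
                for (std::size_t w = 0; w < W; ++w)
                    for (std::size_t c = 0; c < 4; ++c)
                        dst[(((n * CB + cb) * H + h) * W + w) * 4 + c] =
                                src[((n * C + cb * 4 + c) * H + h) * W + w];
    return dst;
}
```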
@@ -11,6 +11,7 @@
#include "src/cuda/relayout_format/opr_impl.h"
#include "src/cuda/handle.h"
#include "src/cuda/utils.h"
using namespace megdnn;
using namespace cuda;
@@ -20,15 +21,22 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
auto src_dtype = src.layout.dtype;
megdnn_assert(
param().mode == param::RelayoutFormat::Mode::NCHW4_CHWN4 ||
param().mode == param::RelayoutFormat::Mode::CHWN4_NCHW4,
param().mode == param::RelayoutFormat::Mode::CHWN4_NCHW4 ||
param().mode == Param::Mode::NCHW_NCHW4_IC_SMALL ||
param().mode ==
Param::Mode::NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT,
"relayout format of cuda only support NCHW4->CHWN4 or "
"CHWN4->NCHW4");
if (src_dtype.enumv() == DTypeEnum::QuantizedS8) {
"CHWN4->NCHW4 or NCHW->NCHW4");
if ((param().mode == param::RelayoutFormat::Mode::NCHW4_CHWN4 ||
param().mode == param::RelayoutFormat::Mode::CHWN4_NCHW4) &&
src_dtype.enumv() == DTypeEnum::QuantizedS8) {
size_t row = 0, col = 0;
if (param().mode == Param::RelayoutFormat::Mode::NCHW4_CHWN4) {
row = src.layout[0],
col = src.layout[1] * src.layout[2] * src.layout[3];
} else {
megdnn_assert(param().mode ==
param::RelayoutFormat::Mode::CHWN4_NCHW4);
row = src.layout[0] * src.layout[1] * src.layout[2],
col = src.layout[3];
}
@@ -43,6 +51,27 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
return handle()->create_operator<RelayoutForward>()->exec(trans_in,
trans_out);
}
if ((param().mode == Param::Mode::NCHW_NCHW4_IC_SMALL ||
param().mode == Param::Mode::NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT) &&
src.layout[1] % 4 != 0) {
megdnn_assert(src.raw_ptr != dst.raw_ptr && src.layout.ndim == 4,
"The mode of NCHW_NCHW4 and NCHW_NCHW4_CONV_DENSE_WEIGHT " | |||
"of RelayoutFormat opr(cuda backend) does not support " | |||
"src.ptr == dst.ptr"); | |||
megdnn_assert(src.layout[1] <= 4); | |||
cuda_check(cudaMemsetAsync(dst.raw_ptr, 0, | |||
dst.layout.span().dist_byte(), | |||
cuda_stream(this->handle()))); | |||
TensorLayout exec_dst_layout = dst.layout; | |||
exec_dst_layout[4] = src.layout[1]; | |||
TensorLayout exec_src_layout = | |||
src.layout | |||
.reshape({src.layout[0], src.layout[1], 1, | |||
src.layout[2], src.layout[3]}) | |||
.dimshuffle({0, 2, 3, 4, 1}); | |||
return handle()->create_operator<RelayoutForward>()->exec( | |||
{src.raw_ptr, exec_src_layout}, {dst.raw_ptr, exec_dst_layout}); | |||
} | |||
TensorLayout exec_src, exec_dst; | |||
deduce_exec_layout(src.layout, dst.layout, exec_src, exec_dst); | |||
TensorND exec_src_nd{src.raw_ptr, exec_src}; | |||
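Note: a rough cpu-side analogue of the new small-IC branch above (names are illustrative): the whole NCHW4 destination is cleared first, then the C <= 4 real channels are scattered into the first C slots of each 4-wide block, mirroring what the cudaMemsetAsync plus RelayoutForward pair does on device.

```cpp
#include <cstddef>
#include <cstring>

// src: contiguous NCHW with C <= 4 and C % 4 != 0.
// dst: contiguous (N, 1, H, W, 4); padding channels end up zero.
void nchw_to_nchw4_ic_small_ref(const float* src, float* dst, std::size_t N,
                                std::size_t C, std::size_t H, std::size_t W) {
    std::memset(dst, 0, N * H * W * 4 * sizeof(float));
    for (std::size_t n = 0; n < N; ++n)
        for (std::size_t c = 0; c < C; ++c)
            for (std::size_t h = 0; h < H; ++h)
                for (std::size_t w = 0; w < W; ++w)
                    dst[((n * H + h) * W + w) * 4 + c] =
                            src[((n * C + c) * H + h) * W + w];
}
```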
@@ -79,6 +79,7 @@ void padding_to_workspace(_megdnn_tensor_out dst, _megdnn_tensor_in src,
}
cb(Float32, dt_float32);
cb(QuantizedS8, dt_qint8);
default:
megdnn_assert(0);
#undef cb
@@ -138,7 +139,7 @@ size_t RelayoutFormatImpl::get_workspace_in_bytes(const TensorLayout& src,
return n * c * h * w * src.dtype.size();
}
case Param::Mode::NCHW_NCHW88_CONV_DENSE_WEIGHT: {
megdnn_assert(src.ndim == 4, "src must be oihw ,nmdim == 5");
megdnn_assert(src.ndim == 4, "src must be oihw, ndim == 5"); | |||
megdnn_assert(src[0] % 8 == 0,
"NCHW_NCHW88_CONV_DENSE_WEIGHT oc must align to 8");
if (src[1] % 8 == 0)
@@ -150,7 +151,7 @@ size_t RelayoutFormatImpl::get_workspace_in_bytes(const TensorLayout& src,
return oc * ic * h * w * src.dtype.size();
}
case Param::Mode::NCHW_NCHW88_CONV_GROUP_WEIGHT: {
megdnn_assert(src.ndim == 5, "src must be goihw ,nmdim == 5");
megdnn_assert(src.ndim == 5, "src must be goihw, ndim == 5");
megdnn_assert(src[1] % 8 == 0,
"NCHW_NCHW88_CONV_CHAN_WEIGHT oc per group must " | |||
"align to 8"); | |||
@@ -164,7 +165,7 @@ size_t RelayoutFormatImpl::get_workspace_in_bytes(const TensorLayout& src, | |||
return group * ocpg * icpg * h * w * src.dtype.size(); | |||
} | |||
case Param::Mode::NCHW_NCHW88_CONV_CHAN_WEIGHT: { | |||
megdnn_assert(src.ndim == 5, "src must be goihw ,nmdim == 5"); | |||
megdnn_assert(src.ndim == 5, "src must be goihw, ndim == 5"); | |||
if (src[0] % 8 == 0) | |||
return 0; | |||
size_t group = round_up(src[0], 8_z); | |||
@@ -174,6 +175,27 @@ size_t RelayoutFormatImpl::get_workspace_in_bytes(const TensorLayout& src, | |||
size_t w = src[4]; | |||
return group * ocpg * icpg * h * w * src.dtype.size(); | |||
} | |||
case Param::Mode::NCHW_NCHW4_IC_SMALL: { | |||
if (src[1] % 4 == 0) | |||
return 0; | |||
size_t n = src[0]; | |||
size_t c = round_up(src[1], 4_z); | |||
size_t h = src[2]; | |||
size_t w = src[3]; | |||
return n * c * h * w * src.dtype.size(); | |||
} | |||
case Param::Mode::NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT: { | |||
megdnn_assert(src.ndim == 4, "src must be oihw, ndim == 5"); | |||
if (src[1] % 4 == 0)
return 0;
size_t oc = src[0];
size_t ic = round_up(src[1], 4_z);
size_t h = src[2];
size_t w = src[3];
return oc * ic * h * w * src.dtype.size();
}
default:
return 0;
}
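Note: the two new workspace cases above reduce to one rule; a minimal restatement with a hypothetical helper:

```cpp
#include <cstddef>

// No scratch space is needed when the channel (axis-1) extent is already a
// multiple of 4; otherwise reserve room for the zero-padded NCHW/OIHW copy.
std::size_t ic_small_workspace_bytes(std::size_t n, std::size_t c,
                                     std::size_t h, std::size_t w,
                                     std::size_t dtype_size) {
    if (c % 4 == 0)
        return 0;
    const std::size_t c_padded = (c + 3) / 4 * 4;  // round_up(c, 4)
    return n * c_padded * h * w * dtype_size;
}
```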
@@ -244,31 +266,28 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
exec_src_nd.raw_ptr = workspace.raw_ptr;
}
} else if (param().mode == Param::Mode::NCHW_NCHW88) {
size_t ic = src.layout[1];
if (ic % 8 != 0) {
padding_to_workspace({workspace.raw_ptr, exec_src}, src, 1, 8);
exec_src_nd.raw_ptr = workspace.raw_ptr;
}
#define cb(_idx, _pack_size) \
size_t val = src.layout[_idx]; \
if (val % _pack_size != 0) { \
padding_to_workspace({workspace.raw_ptr, exec_src}, src, _idx, \
_pack_size); \
exec_src_nd.raw_ptr = workspace.raw_ptr; \
}
cb(1, 8);
} else if (param().mode == Param::Mode::NCHW_NCHW88_CONV_DENSE_WEIGHT) {
megdnn_assert(src.layout[0] % 8 == 0);
size_t ic = src.layout[1];
if (ic % 8 != 0) {
padding_to_workspace({workspace.raw_ptr, exec_src}, src, 1, 8_z);
exec_src_nd.raw_ptr = workspace.raw_ptr;
}
cb(1, 8);
} else if (param().mode == Param::Mode::NCHW_NCHW88_CONV_CHAN_WEIGHT) {
size_t group = src.layout[0];
if (group % 8 != 0) {
padding_to_workspace({workspace.raw_ptr, exec_src}, src, 0, 8_z);
exec_src_nd.raw_ptr = workspace.raw_ptr;
}
cb(0, 8);
} else if (param().mode == Param::Mode::NCHW_NCHW88_CONV_GROUP_WEIGHT) {
megdnn_assert(src.layout[1] % 8 == 0);
size_t ic = src.layout[2];
if (ic % 8 != 0) {
padding_to_workspace({workspace.raw_ptr, exec_src}, src, 2, 8_z);
exec_src_nd.raw_ptr = workspace.raw_ptr;
}
cb(2, 8);
} else if (param().mode == Param::Mode::NCHW_NCHW4_IC_SMALL) {
cb(1, 4);
} else if (param().mode ==
Param::Mode::NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT) {
cb(1, 4);
}
m_handle->relayout_opr()->exec(exec_src_nd, exec_dst_nd, handle());
}
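Note: the cb(_idx, _pack_size) macro above folds the four hand-copied padding branches into one parameterized body; a toy, self-contained illustration of the same pattern (everything here is local to the sketch):

```cpp
#include <cstddef>
#include <cstdio>

int main() {
    const std::size_t shape[4] = {2, 3, 35, 36};  // an NCHW shape with C = 3
    bool needs_padding = false;
// Mirrors the structure of the real macro: check one axis against one pack
// size and remember whether a padded workspace copy would be required.
#define cb(_idx, _pack_size)              \
    if (shape[_idx] % _pack_size != 0) {  \
        needs_padding = true;             \
    }
    cb(1, 4);  // NCHW_NCHW4_IC_SMALL pads axis 1 (channels) up to 4
#undef cb
    std::printf("needs_padding = %d\n", needs_padding ? 1 : 0);
    return 0;
}
```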
@@ -8,6 +8,7 @@
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/
#include "megdnn/dtype.h"
#include "megdnn/oprs.h"
#include "test/common/checker.h"
#include "test/common/rng.h"
@@ -30,4 +31,25 @@ TEST_F(CUDA, RELAYOUT_FORMAT) {
checker.execs({{22, 23, 24, 25, 4}, {}});
}
TEST_F(CUDA, RELAYOUT_FORMAT_NCHW4) {
Checker<RelayoutFormat> checker(handle_cuda());
UniformIntRNG rng{-50, 50};
param::RelayoutFormat param;
param.mode = param::RelayoutFormat::Mode::NCHW_NCHW4_IC_SMALL;
for (DType dtype :
std::vector<DType>({dtype::QuantizedS8{0.1f}, dtype::Float32{}})) {
checker.set_dtype(0, dtype).set_rng(0, &rng);
checker.set_param(param).execs({{2, 4, 35, 36}, {}});
checker.set_param(param).execs({{2, 3, 35, 36}, {}});
checker.set_param(param).execs({{2, 1, 35, 36}, {}});
param.mode = param::RelayoutFormat::Mode::
NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT;
checker.set_param(param).execs({{4, 3, 3, 3}, {}});
checker.set_param(param).execs({{4, 4, 3, 3}, {}});
checker.set_param(param).execs({{1, 4, 3, 3}, {}});
}
}
// vim: syntax=cpp.doxygen