GitOrigin-RevId: 07f2ee6c5b
release-0.5
@@ -434,7 +434,7 @@ pdef('PowC', 'power with constant exponent').add_fields('float32', 'exp', 0) | |||||
'layout is (K/4, M/4, 4(k), 4(m)) x (K/4, N, 4(k))'), | 'layout is (K/4, M/4, 4(k), 4(m)) x (K/4, N, 4(k))'), | ||||
Doc('MK8', 'Split 8 from M and K, better for neon compute:' | Doc('MK8', 'Split 8 from M and K, better for neon compute:' | ||||
'(M/8, K/8, 8(k), 8(m)) x (K/8, N, 8(k)). if transposeA the ' | '(M/8, K/8, 8(k), 8(m)) x (K/8, N, 8(k)). if transposeA the ' | ||||
'layout is (K/8, M/8, 8(k), 8(m)) x (K/8, N, 8(k))'), | |||||
'layout is (K/8, M/8, 8(k), 8(m)) x (K/8, N, 8(k))'), | |||||
Doc('MK4_DOT', 'Split 4 from M and K, better for neon dotprod:' | Doc('MK4_DOT', 'Split 4 from M and K, better for neon dotprod:' | ||||
'M/4, K/4, 4(m), 4(k)) x (K/4, N, 4(k)). if transposeA the ' | 'M/4, K/4, 4(m), 4(k)) x (K/4, N, 4(k)). if transposeA the ' | ||||
'layout is (K/4, M/4, 4(m), 4(k)) x (K/4, N, 4(k))')) | 'layout is (K/4, M/4, 4(m), 4(k)) x (K/4, N, 4(k))')) | ||||
@@ -858,7 +858,10 @@ when the ``I`` suffix is present. | |||||
'NCHW_NCHW88_CONV_CHAN_WEIGHT', | 'NCHW_NCHW88_CONV_CHAN_WEIGHT', | ||||
'NCHW_NCHW88_CONV_GROUP_WEIGHT', | 'NCHW_NCHW88_CONV_GROUP_WEIGHT', | ||||
'NCHW_NCHW88', | 'NCHW_NCHW88', | ||||
'NCHW88_NCHW') | |||||
'NCHW88_NCHW', | |||||
'NCHW_NCHW4_IC_SMALL', | |||||
'NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT', | |||||
) | |||||
) | ) | ||||
@@ -28,6 +28,26 @@ void RelayoutFormat::deduce_layout_fwd(const TensorLayout& src, | |||||
dst[3] = src[3]; | dst[3] = src[3]; | ||||
dst[4] = 4; | dst[4] = 4; | ||||
break; | break; | ||||
case Param::Mode::NCHW_NCHW4_IC_SMALL: | |||||
dst.ndim = 5; | |||||
megdnn_assert(src[1] <= 4_z, "ic should be less equal 4"); | |||||
dst[0] = src[0]; | |||||
dst[1] = div_ceil(src[1], 4_z); | |||||
dst[2] = src[2]; | |||||
dst[3] = src[3]; | |||||
dst[4] = 4; | |||||
break; | |||||
case Param::Mode::NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT: | |||||
megdnn_assert(src.ndim == 4, "src must be oihw, ndim == 4"); | |||||
megdnn_assert(src[1] <= 4_z, "ic should be less equal 4"); | |||||
dst.ndim = 5; | |||||
dst[0] = src[0]; | |||||
dst[1] = div_ceil(src[1], 4_z); | |||||
dst[2] = src[2]; | |||||
dst[3] = src[3]; | |||||
dst[4] = 4; | |||||
break; | |||||
case Param::Mode::NCHW_NCHW88: | case Param::Mode::NCHW_NCHW88: | ||||
dst.ndim = 5; | dst.ndim = 5; | ||||
dst[0] = src[0]; | dst[0] = src[0]; | ||||
@@ -276,6 +296,8 @@ void RelayoutFormat::deduce_format(TensorFormat src, TensorFormat& dst) { | |||||
case Param::Mode::NCHW_NCHW88_CONV_DENSE_WEIGHT: | case Param::Mode::NCHW_NCHW88_CONV_DENSE_WEIGHT: | ||||
case Param::Mode::NCHW_NCHW88_CONV_CHAN_WEIGHT: | case Param::Mode::NCHW_NCHW88_CONV_CHAN_WEIGHT: | ||||
case Param::Mode::NCHW_NCHW88_CONV_GROUP_WEIGHT: | case Param::Mode::NCHW_NCHW88_CONV_GROUP_WEIGHT: | ||||
case Param::Mode::NCHW_NCHW4_IC_SMALL: | |||||
case Param::Mode::NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT: | |||||
CHECK_SRC(DefaultTensorFormat::make()); | CHECK_SRC(DefaultTensorFormat::make()); | ||||
dst = src; | dst = src; | ||||
break; | break; | ||||
@@ -374,6 +396,23 @@ void RelayoutFormat::deduce_exec_layout(const TensorLayout& src, | |||||
exec_dst = dst; | exec_dst = dst; | ||||
} | } | ||||
break; | break; | ||||
case Param::Mode::NCHW_NCHW4_IC_SMALL: | |||||
case Param::Mode::NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT: | |||||
// nchw to nchw4c or oihw to oihw4i | |||||
{ | |||||
TensorLayout work_space_layout( | |||||
{src[0], round_up(src[1], 4_z), src[2], src[3]}, | |||||
src.dtype, src.format); | |||||
exec_src = work_space_layout | |||||
.reshape({src[0], div_ceil(src[1], 4_z), 4, | |||||
src[2], src[3]}) | |||||
.dimshuffle({0, 1, 3, 4, 2}); | |||||
exec_dst = dst; | |||||
} | |||||
break; | |||||
case Param::Mode::NCHW_NHWCD4: | case Param::Mode::NCHW_NHWCD4: | ||||
case Param::Mode::NCHW_NHWCD4I: | case Param::Mode::NCHW_NHWCD4I: | ||||
// src is {N, C, H, W} | // src is {N, C, H, W} | ||||
@@ -11,6 +11,7 @@ | |||||
#include "src/cuda/relayout_format/opr_impl.h" | #include "src/cuda/relayout_format/opr_impl.h" | ||||
#include "src/cuda/handle.h" | #include "src/cuda/handle.h" | ||||
#include "src/cuda/utils.h" | |||||
using namespace megdnn; | using namespace megdnn; | ||||
using namespace cuda; | using namespace cuda; | ||||
@@ -20,15 +21,22 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||||
auto src_dtype = src.layout.dtype; | auto src_dtype = src.layout.dtype; | ||||
megdnn_assert( | megdnn_assert( | ||||
param().mode == param::RelayoutFormat::Mode::NCHW4_CHWN4 || | param().mode == param::RelayoutFormat::Mode::NCHW4_CHWN4 || | ||||
param().mode == param::RelayoutFormat::Mode::CHWN4_NCHW4, | |||||
param().mode == param::RelayoutFormat::Mode::CHWN4_NCHW4 || | |||||
param().mode == Param::Mode::NCHW_NCHW4_IC_SMALL || | |||||
param().mode == | |||||
Param::Mode::NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT, | |||||
"relayout format of cuda only support NCHW4->CHWN4 or " | "relayout format of cuda only support NCHW4->CHWN4 or " | ||||
"CHWN4->NCHW4"); | |||||
if (src_dtype.enumv() == DTypeEnum::QuantizedS8) { | |||||
"CHWN4->NCHW4 or NCHW->NCHW4"); | |||||
if ((param().mode == param::RelayoutFormat::Mode::NCHW4_CHWN4 || | |||||
param().mode == param::RelayoutFormat::Mode::CHWN4_NCHW4) && | |||||
src_dtype.enumv() == DTypeEnum::QuantizedS8) { | |||||
size_t row = 0, col = 0; | size_t row = 0, col = 0; | ||||
if (param().mode == Param::RelayoutFormat::Mode::NCHW4_CHWN4) { | if (param().mode == Param::RelayoutFormat::Mode::NCHW4_CHWN4) { | ||||
row = src.layout[0], | row = src.layout[0], | ||||
col = src.layout[1] * src.layout[2] * src.layout[3]; | col = src.layout[1] * src.layout[2] * src.layout[3]; | ||||
} else { | } else { | ||||
megdnn_assert(param().mode == | |||||
param::RelayoutFormat::Mode::CHWN4_NCHW4); | |||||
row = src.layout[0] * src.layout[1] * src.layout[2], | row = src.layout[0] * src.layout[1] * src.layout[2], | ||||
col = src.layout[3]; | col = src.layout[3]; | ||||
} | } | ||||
@@ -43,6 +51,27 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||||
return handle()->create_operator<RelayoutForward>()->exec(trans_in, | return handle()->create_operator<RelayoutForward>()->exec(trans_in, | ||||
trans_out); | trans_out); | ||||
} | } | ||||
if ((param().mode == Param::Mode::NCHW_NCHW4_IC_SMALL || | |||||
param().mode == Param::Mode::NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT) && | |||||
src.layout[1] % 4 != 0) { | |||||
megdnn_assert(src.raw_ptr != dst.raw_ptr && src.layout.ndim == 4, | |||||
"The mode of NCHW_NCHW4 and NCHW_NCHW4_CONV_DENSE_WEIGHT " | |||||
"of RelayoutFormat opr(cuda backend) does not support " | |||||
"src.ptr == dst.ptr"); | |||||
megdnn_assert(src.layout[1] <= 4); | |||||
cuda_check(cudaMemsetAsync(dst.raw_ptr, 0, | |||||
dst.layout.span().dist_byte(), | |||||
cuda_stream(this->handle()))); | |||||
TensorLayout exec_dst_layout = dst.layout; | |||||
exec_dst_layout[4] = src.layout[1]; | |||||
TensorLayout exec_src_layout = | |||||
src.layout | |||||
.reshape({src.layout[0], src.layout[1], 1, | |||||
src.layout[2], src.layout[3]}) | |||||
.dimshuffle({0, 2, 3, 4, 1}); | |||||
return handle()->create_operator<RelayoutForward>()->exec( | |||||
{src.raw_ptr, exec_src_layout}, {dst.raw_ptr, exec_dst_layout}); | |||||
} | |||||
TensorLayout exec_src, exec_dst; | TensorLayout exec_src, exec_dst; | ||||
deduce_exec_layout(src.layout, dst.layout, exec_src, exec_dst); | deduce_exec_layout(src.layout, dst.layout, exec_src, exec_dst); | ||||
TensorND exec_src_nd{src.raw_ptr, exec_src}; | TensorND exec_src_nd{src.raw_ptr, exec_src}; | ||||
@@ -79,6 +79,7 @@ void padding_to_workspace(_megdnn_tensor_out dst, _megdnn_tensor_in src, | |||||
} | } | ||||
cb(Float32, dt_float32); | cb(Float32, dt_float32); | ||||
cb(QuantizedS8, dt_qint8); | |||||
default: | default: | ||||
megdnn_assert(0); | megdnn_assert(0); | ||||
#undef cb | #undef cb | ||||
@@ -138,7 +139,7 @@ size_t RelayoutFormatImpl::get_workspace_in_bytes(const TensorLayout& src, | |||||
return n * c * h * w * src.dtype.size(); | return n * c * h * w * src.dtype.size(); | ||||
} | } | ||||
case Param::Mode::NCHW_NCHW88_CONV_DENSE_WEIGHT: { | case Param::Mode::NCHW_NCHW88_CONV_DENSE_WEIGHT: { | ||||
megdnn_assert(src.ndim == 4, "src must be oihw ,nmdim == 5"); | |||||
megdnn_assert(src.ndim == 4, "src must be oihw, ndim == 5"); | |||||
megdnn_assert(src[0] % 8 == 0, | megdnn_assert(src[0] % 8 == 0, | ||||
"NCHW_NCHW88_CONV_DENSE_WEIGHT oc must align to 8"); | "NCHW_NCHW88_CONV_DENSE_WEIGHT oc must align to 8"); | ||||
if (src[1] % 8 == 0) | if (src[1] % 8 == 0) | ||||
@@ -150,7 +151,7 @@ size_t RelayoutFormatImpl::get_workspace_in_bytes(const TensorLayout& src, | |||||
return oc * ic * h * w * src.dtype.size(); | return oc * ic * h * w * src.dtype.size(); | ||||
} | } | ||||
case Param::Mode::NCHW_NCHW88_CONV_GROUP_WEIGHT: { | case Param::Mode::NCHW_NCHW88_CONV_GROUP_WEIGHT: { | ||||
megdnn_assert(src.ndim == 5, "src must be goihw ,nmdim == 5"); | |||||
megdnn_assert(src.ndim == 5, "src must be goihw, ndim == 5"); | |||||
megdnn_assert(src[1] % 8 == 0, | megdnn_assert(src[1] % 8 == 0, | ||||
"NCHW_NCHW88_CONV_CHAN_WEIGHT oc per group must " | "NCHW_NCHW88_CONV_CHAN_WEIGHT oc per group must " | ||||
"align to 8"); | "align to 8"); | ||||
@@ -164,7 +165,7 @@ size_t RelayoutFormatImpl::get_workspace_in_bytes(const TensorLayout& src, | |||||
return group * ocpg * icpg * h * w * src.dtype.size(); | return group * ocpg * icpg * h * w * src.dtype.size(); | ||||
} | } | ||||
case Param::Mode::NCHW_NCHW88_CONV_CHAN_WEIGHT: { | case Param::Mode::NCHW_NCHW88_CONV_CHAN_WEIGHT: { | ||||
megdnn_assert(src.ndim == 5, "src must be goihw ,nmdim == 5"); | |||||
megdnn_assert(src.ndim == 5, "src must be goihw, ndim == 5"); | |||||
if (src[0] % 8 == 0) | if (src[0] % 8 == 0) | ||||
return 0; | return 0; | ||||
size_t group = round_up(src[0], 8_z); | size_t group = round_up(src[0], 8_z); | ||||
@@ -174,6 +175,27 @@ size_t RelayoutFormatImpl::get_workspace_in_bytes(const TensorLayout& src, | |||||
size_t w = src[4]; | size_t w = src[4]; | ||||
return group * ocpg * icpg * h * w * src.dtype.size(); | return group * ocpg * icpg * h * w * src.dtype.size(); | ||||
} | } | ||||
case Param::Mode::NCHW_NCHW4_IC_SMALL: { | |||||
if (src[1] % 4 == 0) | |||||
return 0; | |||||
size_t n = src[0]; | |||||
size_t c = round_up(src[1], 4_z); | |||||
size_t h = src[2]; | |||||
size_t w = src[3]; | |||||
return n * c * h * w * src.dtype.size(); | |||||
} | |||||
case Param::Mode::NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT: { | |||||
megdnn_assert(src.ndim == 4, "src must be oihw, ndim == 5"); | |||||
if (src[1] % 4 == 0) | |||||
return 0; | |||||
size_t oc = src[0]; | |||||
size_t ic = round_up(src[1], 4_z); | |||||
size_t h = src[2]; | |||||
size_t w = src[3]; | |||||
return oc * ic * h * w * src.dtype.size(); | |||||
} | |||||
default: | default: | ||||
return 0; | return 0; | ||||
} | } | ||||
@@ -244,31 +266,28 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||||
exec_src_nd.raw_ptr = workspace.raw_ptr; | exec_src_nd.raw_ptr = workspace.raw_ptr; | ||||
} | } | ||||
} else if (param().mode == Param::Mode::NCHW_NCHW88) { | } else if (param().mode == Param::Mode::NCHW_NCHW88) { | ||||
size_t ic = src.layout[1]; | |||||
if (ic % 8 != 0) { | |||||
padding_to_workspace({workspace.raw_ptr, exec_src}, src, 1, 8); | |||||
exec_src_nd.raw_ptr = workspace.raw_ptr; | |||||
} | |||||
#define cb(_idx, _pack_size) \ | |||||
size_t val = src.layout[_idx]; \ | |||||
if (val % _pack_size != 0) { \ | |||||
padding_to_workspace({workspace.raw_ptr, exec_src}, src, _idx, \ | |||||
_pack_size); \ | |||||
exec_src_nd.raw_ptr = workspace.raw_ptr; \ | |||||
} | |||||
cb(1, 8); | |||||
} else if (param().mode == Param::Mode::NCHW_NCHW88_CONV_DENSE_WEIGHT) { | } else if (param().mode == Param::Mode::NCHW_NCHW88_CONV_DENSE_WEIGHT) { | ||||
megdnn_assert(src.layout[0] % 8 == 0); | megdnn_assert(src.layout[0] % 8 == 0); | ||||
size_t ic = src.layout[1]; | |||||
if (ic % 8 != 0) { | |||||
padding_to_workspace({workspace.raw_ptr, exec_src}, src, 1, 8_z); | |||||
exec_src_nd.raw_ptr = workspace.raw_ptr; | |||||
} | |||||
cb(1, 8); | |||||
} else if (param().mode == Param::Mode::NCHW_NCHW88_CONV_CHAN_WEIGHT) { | } else if (param().mode == Param::Mode::NCHW_NCHW88_CONV_CHAN_WEIGHT) { | ||||
size_t group = src.layout[0]; | |||||
if (group % 8 != 0) { | |||||
padding_to_workspace({workspace.raw_ptr, exec_src}, src, 0, 8_z); | |||||
exec_src_nd.raw_ptr = workspace.raw_ptr; | |||||
} | |||||
cb(0, 8); | |||||
} else if (param().mode == Param::Mode::NCHW_NCHW88_CONV_GROUP_WEIGHT) { | } else if (param().mode == Param::Mode::NCHW_NCHW88_CONV_GROUP_WEIGHT) { | ||||
megdnn_assert(src.layout[1] % 8 == 0); | megdnn_assert(src.layout[1] % 8 == 0); | ||||
size_t ic = src.layout[2]; | |||||
if (ic % 8 != 0) { | |||||
padding_to_workspace({workspace.raw_ptr, exec_src}, src, 2, 8_z); | |||||
exec_src_nd.raw_ptr = workspace.raw_ptr; | |||||
} | |||||
cb(2, 8); | |||||
} else if (param().mode == Param::Mode::NCHW_NCHW4_IC_SMALL) { | |||||
cb(1, 4); | |||||
} else if (param().mode == | |||||
Param::Mode::NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT) { | |||||
cb(1, 4); | |||||
} | } | ||||
m_handle->relayout_opr()->exec(exec_src_nd, exec_dst_nd, handle()); | m_handle->relayout_opr()->exec(exec_src_nd, exec_dst_nd, handle()); | ||||
} | } | ||||
@@ -8,6 +8,7 @@ | |||||
* software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
*/ | */ | ||||
#include "megdnn/dtype.h" | |||||
#include "megdnn/oprs.h" | #include "megdnn/oprs.h" | ||||
#include "test/common/checker.h" | #include "test/common/checker.h" | ||||
#include "test/common/rng.h" | #include "test/common/rng.h" | ||||
@@ -30,4 +31,25 @@ TEST_F(CUDA, RELAYOUT_FORMAT) { | |||||
checker.execs({{22, 23, 24, 25, 4}, {}}); | checker.execs({{22, 23, 24, 25, 4}, {}}); | ||||
} | } | ||||
// Exercises the CUDA RelayoutFormat operator in the two NCHW -> NCHW4
// "IC small" modes (activation layout and dense conv weight layout), for both
// QuantizedS8 and Float32 inputs. Input channel counts of 1, 3 and 4 are used
// so that both the multiple-of-4 and the non-multiple-of-4 channel cases run.
TEST_F(CUDA, RELAYOUT_FORMAT_NCHW4) {
    Checker<RelayoutFormat> checker(handle_cuda());
    UniformIntRNG rng{-50, 50};
    param::RelayoutFormat param;
    for (DType dtype :
         std::vector<DType>({dtype::QuantizedS8{0.1f}, dtype::Float32{}})) {
        checker.set_dtype(0, dtype).set_rng(0, &rng);

        // The mode is assigned inside the loop on purpose: it is switched to
        // the weight mode further down, so it has to be set back here or the
        // second dtype would never run the activation mode.
        param.mode = param::RelayoutFormat::Mode::NCHW_NCHW4_IC_SMALL;
        checker.set_param(param).execs({{2, 4, 35, 36}, {}});
        checker.set_param(param).execs({{2, 3, 35, 36}, {}});
        checker.set_param(param).execs({{2, 1, 35, 36}, {}});

        // Dense conv weight (oihw) variant of the same relayout.
        param.mode = param::RelayoutFormat::Mode::
                NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT;
        checker.set_param(param).execs({{4, 3, 3, 3}, {}});
        checker.set_param(param).execs({{4, 4, 3, 3}, {}});
        checker.set_param(param).execs({{1, 4, 3, 3}, {}});
    }
}
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |