|
|
@@ -14,6 +14,10 @@ |
|
|
|
|
|
|
|
#include "megdnn/tensor_iter.h" |
|
|
|
|
|
|
|
#include "midout.h" |
|
|
|
|
|
|
|
MIDOUT_DECL(megdnn_naive_relayout_format) |
|
|
|
|
|
|
|
using namespace megdnn; |
|
|
|
using namespace naive; |
|
|
|
|
|
|
@@ -222,14 +226,18 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, |
|
|
|
//! ic % 4 != 0 |
|
|
|
if ((IC & 0x3)) { |
|
|
|
switch (src.layout.dtype.enumv()) { |
|
|
|
#define cb(name, ctype) \ |
|
|
|
case (DTypeEnum::name): { \ |
|
|
|
ctype* sptr = src.compatible_ptr<ctype>(); \ |
|
|
|
ctype* dptr = workspace.ptr<ctype>(); \ |
|
|
|
MEGDNN_DISPATCH_CPU_KERN( \ |
|
|
|
m_handle, \ |
|
|
|
padding_src_to_workspace<ctype>(dptr, sptr, N, IC, IH, IW);); \ |
|
|
|
break; \ |
|
|
|
#define cb(name, ctype) \ |
|
|
|
case (DTypeEnum::name): { \ |
|
|
|
MIDOUT_BEGIN(megdnn_naive_relayout_format, ctype, \ |
|
|
|
midout_iv(Param::Mode::NCHW_NHWCD4I)) { \ |
|
|
|
ctype* sptr = src.compatible_ptr<ctype>(); \ |
|
|
|
ctype* dptr = workspace.ptr<ctype>(); \ |
|
|
|
MEGDNN_DISPATCH_CPU_KERN( \ |
|
|
|
m_handle, padding_src_to_workspace<ctype>(dptr, sptr, N, \ |
|
|
|
IC, IH, IW);); \ |
|
|
|
} \ |
|
|
|
MIDOUT_END(); \ |
|
|
|
break; \ |
|
|
|
} |
|
|
|
cb(Float32, dt_float32); |
|
|
|
MEGDNN_INC_FLOAT16(cb(Float16, dt_float16)); |
|
|
@@ -248,14 +256,18 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, |
|
|
|
size_t FW = src.layout[3]; |
|
|
|
if ((IC & 0x3)) { |
|
|
|
switch (src.layout.dtype.enumv()) { |
|
|
|
#define cb(name, ctype) \ |
|
|
|
case (DTypeEnum::name): { \ |
|
|
|
ctype* sptr = src.compatible_ptr<ctype>(); \ |
|
|
|
ctype* dptr = workspace.ptr<ctype>(); \ |
|
|
|
MEGDNN_DISPATCH_CPU_KERN( \ |
|
|
|
m_handle, padding_filter_to_workspace<ctype>(dptr, sptr, OC, \ |
|
|
|
IC, FH, FW);); \ |
|
|
|
break; \ |
|
|
|
#define cb(name, ctype) \ |
|
|
|
case (DTypeEnum::name): { \ |
|
|
|
MIDOUT_BEGIN(megdnn_naive_relayout_format, ctype, \ |
|
|
|
midout_iv(Param::Mode::INTER_WEIGHT_DENSEI_DOT)) { \ |
|
|
|
ctype* sptr = src.compatible_ptr<ctype>(); \ |
|
|
|
ctype* dptr = workspace.ptr<ctype>(); \ |
|
|
|
MEGDNN_DISPATCH_CPU_KERN(m_handle, \ |
|
|
|
padding_filter_to_workspace<ctype>( \ |
|
|
|
dptr, sptr, OC, IC, FH, FW);); \ |
|
|
|
} \ |
|
|
|
MIDOUT_END(); \ |
|
|
|
break; \ |
|
|
|
} |
|
|
|
cb(Quantized8Asymm, dt_uint8); |
|
|
|
cb(QuantizedS8, dt_int8); |
|
|
@@ -266,30 +278,35 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, |
|
|
|
exec_src_nd.raw_ptr = workspace.raw_ptr; |
|
|
|
} |
|
|
|
} else if (param().mode == Param::Mode::NCHW_NCHW88) { |
|
|
|
#define cb(_idx, _pack_size) \ |
|
|
|
size_t val = src.layout[_idx]; \ |
|
|
|
if (val % _pack_size != 0) { \ |
|
|
|
padding_to_workspace({workspace.raw_ptr, exec_src}, src, _idx, \ |
|
|
|
_pack_size); \ |
|
|
|
exec_src_nd.raw_ptr = workspace.raw_ptr; \ |
|
|
|
} |
|
|
|
cb(1, 8); |
|
|
|
#define cb(_idx, _pack_size, _mode) \ |
|
|
|
MIDOUT_BEGIN(megdnn_naive_relayout_format, \ |
|
|
|
midout_iv(Param::Mode::_mode)) { \ |
|
|
|
size_t val = src.layout[_idx]; \ |
|
|
|
if (val % _pack_size != 0) { \ |
|
|
|
padding_to_workspace({workspace.raw_ptr, exec_src}, src, _idx, \ |
|
|
|
_pack_size); \ |
|
|
|
exec_src_nd.raw_ptr = workspace.raw_ptr; \ |
|
|
|
} \ |
|
|
|
} \ |
|
|
|
MIDOUT_END(); |
|
|
|
cb(1, 8, NCHW_NCHW88); |
|
|
|
|
|
|
|
} else if (param().mode == Param::Mode::NCHW_NCHW88_CONV_DENSE_WEIGHT) { |
|
|
|
megdnn_assert(src.layout[0] % 8 == 0); |
|
|
|
cb(1, 8); |
|
|
|
cb(1, 8, NCHW_NCHW88_CONV_DENSE_WEIGHT); |
|
|
|
} else if (param().mode == Param::Mode::NCHW_NCHW88_CONV_CHAN_WEIGHT) { |
|
|
|
cb(0, 8); |
|
|
|
cb(0, 8, NCHW_NCHW88_CONV_CHAN_WEIGHT); |
|
|
|
} else if (param().mode == Param::Mode::NCHW_NCHW88_CONV_GROUP_WEIGHT) { |
|
|
|
megdnn_assert(src.layout[1] % 8 == 0); |
|
|
|
cb(2, 8); |
|
|
|
cb(2, 8, NCHW_NCHW88_CONV_GROUP_WEIGHT); |
|
|
|
} else if (param().mode == Param::Mode::NCHW_NCHW4_IC_SMALL) { |
|
|
|
cb(1, 4); |
|
|
|
cb(1, 4, NCHW_NCHW4_IC_SMALL); |
|
|
|
} else if (param().mode == |
|
|
|
Param::Mode::NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT) { |
|
|
|
cb(1, 4); |
|
|
|
cb(1, 4, NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT); |
|
|
|
} |
|
|
|
m_handle->relayout_opr()->exec(exec_src_nd, exec_dst_nd, handle()); |
|
|
|
#undef cb |
|
|
|
} |
|
|
|
|
|
|
|
// vim: syntax=cpp.doxygen |