|
|
@@ -370,65 +370,67 @@ void pooling_backward_max_impl(const ctype* __restrict src, |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
} // anonymous namespace |
|
|
|
} // namespace |
|
|
|
|
|
|
|
namespace megdnn { |
|
|
|
namespace naive { |
|
|
|
|
|
|
|
void PoolingForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, |
|
|
|
_megdnn_workspace workspace) { |
|
|
|
MIDOUT_BEGIN(megdnn_naive_pooling) { |
|
|
|
check_exec(src.layout, dst.layout, workspace.size); |
|
|
|
size_t c_pos, spatial_pos, batch_pos = 0; |
|
|
|
if (param().format == Param::Format::NCHW || |
|
|
|
param().format == Param::Format::NCHW4 || |
|
|
|
param().format == Param::Format::NCHW88 || |
|
|
|
param().format == Param::Format::NCHW44 || |
|
|
|
param().format == Param::Format::NCHW32) { |
|
|
|
c_pos = 1; |
|
|
|
spatial_pos = 2; |
|
|
|
} else if (param().format == Param::Format::NHWC) { |
|
|
|
c_pos = 3; |
|
|
|
spatial_pos = 1; |
|
|
|
} else if (param().format == Param::Format::CHWN4) { |
|
|
|
c_pos = 0; |
|
|
|
spatial_pos = 1; |
|
|
|
batch_pos = 3; |
|
|
|
} else { |
|
|
|
megdnn_assert(param().format == Param::Format::NHWCD4); |
|
|
|
c_pos = 2; |
|
|
|
spatial_pos = 1; |
|
|
|
} |
|
|
|
size_t N = src.layout.shape[batch_pos], C = src.layout.shape[c_pos], |
|
|
|
IH = src.layout.shape[spatial_pos + 0], |
|
|
|
IW = src.layout.shape[spatial_pos + 1]; |
|
|
|
size_t OH = dst.layout.shape[spatial_pos + 0], |
|
|
|
OW = dst.layout.shape[spatial_pos + 1]; |
|
|
|
if (param().format == Param::Format::NHWCD4) { |
|
|
|
C *= 4; |
|
|
|
IW = src.layout.shape[spatial_pos + 2]; |
|
|
|
OW = dst.layout.shape[spatial_pos + 2]; |
|
|
|
} |
|
|
|
if (param().format == Param::Format::NCHW4 || |
|
|
|
param().format == Param::Format::NCHW44 || |
|
|
|
param().format == Param::Format::CHWN4) { |
|
|
|
C *= 4; |
|
|
|
} |
|
|
|
if (param().format == Param::Format::NCHW88) { |
|
|
|
C *= 8; |
|
|
|
} |
|
|
|
if (param().format == Param::Format::NCHW32) { |
|
|
|
C *= 32; |
|
|
|
} |
|
|
|
size_t PH = param().pad_h, PW = param().pad_w; |
|
|
|
size_t FH = param().window_h, FW = param().window_w; |
|
|
|
size_t SH = param().stride_h, SW = param().stride_w; |
|
|
|
#define DISPATCH_WITH_POOLER_AND_IDX_GETTER(Pooler, IdxGetter) \ |
|
|
|
MEGDNN_DISPATCH_CPU_KERN( \ |
|
|
|
static_cast<naive::HandleImpl*>(handle()), \ |
|
|
|
pooling_forward_impl<Pooler MEGDNN_COMMA IdxGetter>( \ |
|
|
|
sptr, dptr, src.layout.dtype, N, C, IH, IW, OH, OW, PH, \ |
|
|
|
PW, SH, SW, FH, FW)); |
|
|
|
check_exec(src.layout, dst.layout, workspace.size); |
|
|
|
size_t c_pos, spatial_pos, batch_pos = 0; |
|
|
|
if (param().format == Param::Format::NCHW || |
|
|
|
param().format == Param::Format::NCHW4 || |
|
|
|
param().format == Param::Format::NCHW88 || |
|
|
|
param().format == Param::Format::NCHW44 || |
|
|
|
param().format == Param::Format::NCHW32) { |
|
|
|
c_pos = 1; |
|
|
|
spatial_pos = 2; |
|
|
|
} else if (param().format == Param::Format::NHWC) { |
|
|
|
c_pos = 3; |
|
|
|
spatial_pos = 1; |
|
|
|
} else if (param().format == Param::Format::CHWN4) { |
|
|
|
c_pos = 0; |
|
|
|
spatial_pos = 1; |
|
|
|
batch_pos = 3; |
|
|
|
} else { |
|
|
|
megdnn_assert(param().format == Param::Format::NHWCD4); |
|
|
|
c_pos = 2; |
|
|
|
spatial_pos = 1; |
|
|
|
} |
|
|
|
size_t N = src.layout.shape[batch_pos], C = src.layout.shape[c_pos], |
|
|
|
IH = src.layout.shape[spatial_pos + 0], |
|
|
|
IW = src.layout.shape[spatial_pos + 1]; |
|
|
|
size_t OH = dst.layout.shape[spatial_pos + 0], |
|
|
|
OW = dst.layout.shape[spatial_pos + 1]; |
|
|
|
if (param().format == Param::Format::NHWCD4) { |
|
|
|
C *= 4; |
|
|
|
IW = src.layout.shape[spatial_pos + 2]; |
|
|
|
OW = dst.layout.shape[spatial_pos + 2]; |
|
|
|
} |
|
|
|
if (param().format == Param::Format::NCHW4 || |
|
|
|
param().format == Param::Format::NCHW44 || |
|
|
|
param().format == Param::Format::CHWN4) { |
|
|
|
C *= 4; |
|
|
|
} |
|
|
|
if (param().format == Param::Format::NCHW88) { |
|
|
|
C *= 8; |
|
|
|
} |
|
|
|
if (param().format == Param::Format::NCHW32) { |
|
|
|
C *= 32; |
|
|
|
} |
|
|
|
size_t PH = param().pad_h, PW = param().pad_w; |
|
|
|
size_t FH = param().window_h, FW = param().window_w; |
|
|
|
size_t SH = param().stride_h, SW = param().stride_w; |
|
|
|
#define DISPATCH_WITH_POOLER_AND_IDX_GETTER(Pooler, IdxGetter) \ |
|
|
|
MIDOUT_BEGIN(megdnn_naive_pooling, midout_iv(#Pooler #IdxGetter##_hash)) { \ |
|
|
|
MEGDNN_DISPATCH_CPU_KERN( \ |
|
|
|
static_cast<naive::HandleImpl*>(handle()), \ |
|
|
|
pooling_forward_impl<Pooler MEGDNN_COMMA IdxGetter>( \ |
|
|
|
sptr, dptr, src.layout.dtype, N, C, IH, IW, OH, OW, \ |
|
|
|
PH, PW, SH, SW, FH, FW)); \ |
|
|
|
} \ |
|
|
|
MIDOUT_END(); |
|
|
|
|
|
|
|
#define DISPATCH_WITH_POOLER(Pooler) \ |
|
|
|
switch (param().format) { \ |
|
|
@@ -484,14 +486,12 @@ void PoolingForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, |
|
|
|
} \ |
|
|
|
} \ |
|
|
|
} |
|
|
|
MEGDNN_FOREACH_COMPUTING_DTYPE(cb) |
|
|
|
MEGDNN_FOREACH_QUANTIZED_DTYPE(cb) |
|
|
|
MEGDNN_FOREACH_COMPUTING_DTYPE(cb) |
|
|
|
MEGDNN_FOREACH_QUANTIZED_DTYPE(cb) |
|
|
|
#undef cb |
|
|
|
#undef DISPATCH_WITH_POOLER_AND_IDX_GETTER |
|
|
|
#undef DISPATCH_WITH_POOLER |
|
|
|
megdnn_assert_internal(0); |
|
|
|
} |
|
|
|
MIDOUT_END(); |
|
|
|
megdnn_assert_internal(0); |
|
|
|
} |
|
|
|
|
|
|
|
WorkspaceBundle PoolingBackwardImpl::get_workspace_bundle( |
|
|
|