Browse Source

fix(dnn/naive): fix midout for pooling

GitOrigin-RevId: 4edd99f3ec
release-0.5
Megvii Engine Team Xu Xinran 5 years ago
parent
commit
b90c1540db
1 changed files with 58 additions and 58 deletions
  1. +58
    -58
      dnn/src/naive/pooling/opr_impl.cpp

+ 58
- 58
dnn/src/naive/pooling/opr_impl.cpp View File

@@ -370,65 +370,67 @@ void pooling_backward_max_impl(const ctype* __restrict src,
}
}

} // anonymous namespace
} // namespace

namespace megdnn {
namespace naive {

void PoolingForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_workspace workspace) {
MIDOUT_BEGIN(megdnn_naive_pooling) {
check_exec(src.layout, dst.layout, workspace.size);
size_t c_pos, spatial_pos, batch_pos = 0;
if (param().format == Param::Format::NCHW ||
param().format == Param::Format::NCHW4 ||
param().format == Param::Format::NCHW88 ||
param().format == Param::Format::NCHW44 ||
param().format == Param::Format::NCHW32) {
c_pos = 1;
spatial_pos = 2;
} else if (param().format == Param::Format::NHWC) {
c_pos = 3;
spatial_pos = 1;
} else if (param().format == Param::Format::CHWN4) {
c_pos = 0;
spatial_pos = 1;
batch_pos = 3;
} else {
megdnn_assert(param().format == Param::Format::NHWCD4);
c_pos = 2;
spatial_pos = 1;
}
size_t N = src.layout.shape[batch_pos], C = src.layout.shape[c_pos],
IH = src.layout.shape[spatial_pos + 0],
IW = src.layout.shape[spatial_pos + 1];
size_t OH = dst.layout.shape[spatial_pos + 0],
OW = dst.layout.shape[spatial_pos + 1];
if (param().format == Param::Format::NHWCD4) {
C *= 4;
IW = src.layout.shape[spatial_pos + 2];
OW = dst.layout.shape[spatial_pos + 2];
}
if (param().format == Param::Format::NCHW4 ||
param().format == Param::Format::NCHW44 ||
param().format == Param::Format::CHWN4) {
C *= 4;
}
if (param().format == Param::Format::NCHW88) {
C *= 8;
}
if (param().format == Param::Format::NCHW32) {
C *= 32;
}
size_t PH = param().pad_h, PW = param().pad_w;
size_t FH = param().window_h, FW = param().window_w;
size_t SH = param().stride_h, SW = param().stride_w;
#define DISPATCH_WITH_POOLER_AND_IDX_GETTER(Pooler, IdxGetter) \
MEGDNN_DISPATCH_CPU_KERN( \
static_cast<naive::HandleImpl*>(handle()), \
pooling_forward_impl<Pooler MEGDNN_COMMA IdxGetter>( \
sptr, dptr, src.layout.dtype, N, C, IH, IW, OH, OW, PH, \
PW, SH, SW, FH, FW));
check_exec(src.layout, dst.layout, workspace.size);
size_t c_pos, spatial_pos, batch_pos = 0;
if (param().format == Param::Format::NCHW ||
param().format == Param::Format::NCHW4 ||
param().format == Param::Format::NCHW88 ||
param().format == Param::Format::NCHW44 ||
param().format == Param::Format::NCHW32) {
c_pos = 1;
spatial_pos = 2;
} else if (param().format == Param::Format::NHWC) {
c_pos = 3;
spatial_pos = 1;
} else if (param().format == Param::Format::CHWN4) {
c_pos = 0;
spatial_pos = 1;
batch_pos = 3;
} else {
megdnn_assert(param().format == Param::Format::NHWCD4);
c_pos = 2;
spatial_pos = 1;
}
size_t N = src.layout.shape[batch_pos], C = src.layout.shape[c_pos],
IH = src.layout.shape[spatial_pos + 0],
IW = src.layout.shape[spatial_pos + 1];
size_t OH = dst.layout.shape[spatial_pos + 0],
OW = dst.layout.shape[spatial_pos + 1];
if (param().format == Param::Format::NHWCD4) {
C *= 4;
IW = src.layout.shape[spatial_pos + 2];
OW = dst.layout.shape[spatial_pos + 2];
}
if (param().format == Param::Format::NCHW4 ||
param().format == Param::Format::NCHW44 ||
param().format == Param::Format::CHWN4) {
C *= 4;
}
if (param().format == Param::Format::NCHW88) {
C *= 8;
}
if (param().format == Param::Format::NCHW32) {
C *= 32;
}
size_t PH = param().pad_h, PW = param().pad_w;
size_t FH = param().window_h, FW = param().window_w;
size_t SH = param().stride_h, SW = param().stride_w;
#define DISPATCH_WITH_POOLER_AND_IDX_GETTER(Pooler, IdxGetter) \
MIDOUT_BEGIN(megdnn_naive_pooling, midout_iv(#Pooler #IdxGetter##_hash)) { \
MEGDNN_DISPATCH_CPU_KERN( \
static_cast<naive::HandleImpl*>(handle()), \
pooling_forward_impl<Pooler MEGDNN_COMMA IdxGetter>( \
sptr, dptr, src.layout.dtype, N, C, IH, IW, OH, OW, \
PH, PW, SH, SW, FH, FW)); \
} \
MIDOUT_END();

#define DISPATCH_WITH_POOLER(Pooler) \
switch (param().format) { \
@@ -484,14 +486,12 @@ void PoolingForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
} \
} \
}
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
MEGDNN_FOREACH_QUANTIZED_DTYPE(cb)
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
MEGDNN_FOREACH_QUANTIZED_DTYPE(cb)
#undef cb
#undef DISPATCH_WITH_POOLER_AND_IDX_GETTER
#undef DISPATCH_WITH_POOLER
megdnn_assert_internal(0);
}
MIDOUT_END();
megdnn_assert_internal(0);
}

WorkspaceBundle PoolingBackwardImpl::get_workspace_bundle(


Loading…
Cancel
Save