GitOrigin-RevId: 613a18dd91
release-1.5
@@ -20,6 +20,8 @@ | |||||
#include "src/x86/pooling/pooling_special_cases.h" | #include "src/x86/pooling/pooling_special_cases.h" | ||||
#include "src/x86/utils.h" | #include "src/x86/utils.h" | ||||
#include "src/x86/avx_helper.h" | |||||
using namespace megdnn; | using namespace megdnn; | ||||
using namespace x86; | using namespace x86; | ||||
@@ -65,6 +67,7 @@ PoolingImpl::AlgoPack::AlgoPack() { | |||||
all_algos.push_back(&algo_mean_w2s2_sse3); | all_algos.push_back(&algo_mean_w2s2_sse3); | ||||
all_algos.push_back(&algo_max_w2s2_sse); | all_algos.push_back(&algo_max_w2s2_sse); | ||||
all_algos.push_back(&algo_max_w3s3_sse); | all_algos.push_back(&algo_max_w3s3_sse); | ||||
all_algos.push_back(&algo_max_w13s1_nchw88_avx); | |||||
#if MEGDNN_X86_WITH_MKL_DNN | #if MEGDNN_X86_WITH_MKL_DNN | ||||
all_algos.push_back(&algo_mkldnn_nchw); | all_algos.push_back(&algo_mkldnn_nchw); | ||||
all_algos.push_back(&algo_mkldnn_nchw88); | all_algos.push_back(&algo_mkldnn_nchw88); | ||||
@@ -362,4 +365,136 @@ void PoolingImpl::AlgoMKLDNNNCHW88::exec(const ExecArgs& args) const { | |||||
MEGDNN_DISPATCH_CPU_KERN_OPR(run()); | MEGDNN_DISPATCH_CPU_KERN_OPR(run()); | ||||
} | } | ||||
#endif | |||||
#endif | |||||
namespace { | |||||
MEGDNN_ATTRIBUTE_TARGET("avx") | |||||
void max_pooling_s1_nchw88_avx_kern(const float* src, float* dst, int IH, | |||||
int IW, int OH, int OW, int PH, int PW, | |||||
int WH, int WW) { | |||||
static float min_float = -std::numeric_limits<float>::max(); | |||||
static int VECSIZE = 8; | |||||
__m256 ymm[16]; | |||||
const float* psrc = src; | |||||
float* pdst = dst; | |||||
//! deal all rows | |||||
for (int row = 0; row < IH; ++row) { | |||||
for (int j = 0; j < PW; ++j) { | |||||
ymm[j] = _mm256_set1_ps(min_float); | |||||
} | |||||
int col_end = WW - PW < IW ? WW - PW : IW; | |||||
for (int j = 0; j < col_end; ++j) { | |||||
ymm[j + PW] = _mm256_loadu_ps(psrc + j * VECSIZE); | |||||
} | |||||
for (int j = col_end + PW; j < WW; ++j) { | |||||
ymm[j] = _mm256_set1_ps(min_float); | |||||
} | |||||
int col_next = WW - PW; | |||||
for (int j = 0; j < OW; ++j) { | |||||
for (int i = WW - 2; i >= 0; --i) { | |||||
ymm[i] = _mm256_max_ps(ymm[i], ymm[i + 1]); | |||||
} | |||||
_mm256_storeu_ps(pdst, ymm[0]); | |||||
pdst += VECSIZE; | |||||
for (int i = 0; i < WW - 1; ++i) { | |||||
ymm[i] = ymm[i + 1]; | |||||
} | |||||
if (col_next < IW) { | |||||
ymm[WW - 1] = _mm256_loadu_ps(psrc + col_next * VECSIZE); | |||||
col_next++; | |||||
} else { | |||||
ymm[WW - 1] = _mm256_set1_ps(min_float); | |||||
} | |||||
} | |||||
psrc += IW * VECSIZE; | |||||
} | |||||
//! deal all cols | |||||
float* src1 = dst; | |||||
for (int col = 0; col < OW; ++col) { | |||||
for (int j = 0; j < PH; ++j) { | |||||
ymm[j] = _mm256_set1_ps(min_float); | |||||
} | |||||
int row_end = WH - PH < IH ? WH - PH : IH; | |||||
for (int j = 0; j < row_end; ++j) { | |||||
ymm[j + PH] = _mm256_loadu_ps(src1 + j * OW * VECSIZE); | |||||
} | |||||
for (int j = row_end + PH; j < WH; ++j) { | |||||
ymm[j] = _mm256_set1_ps(min_float); | |||||
} | |||||
int row_next = WH - PH; | |||||
pdst = src1; | |||||
for (int j = 0; j < OH; ++j) { | |||||
for (int i = WH - 2; i >= 0; --i) { | |||||
ymm[i] = _mm256_max_ps(ymm[i], ymm[i + 1]); | |||||
} | |||||
_mm256_storeu_ps(pdst, ymm[0]); | |||||
pdst += OW * VECSIZE; | |||||
for (int i = 0; i < WH - 1; ++i) { | |||||
ymm[i] = ymm[i + 1]; | |||||
} | |||||
if (row_next < IH) { | |||||
ymm[WH - 1] = _mm256_loadu_ps(src1 + row_next * OW * VECSIZE); | |||||
row_next++; | |||||
} else { | |||||
ymm[WH - 1] = _mm256_set1_ps(min_float); | |||||
} | |||||
} | |||||
src1 += VECSIZE; | |||||
} | |||||
} | |||||
} // namespace | |||||
bool PoolingImpl::AlgoMaxS1NCHW88AVX::is_available(const SizeArgs& args) const { | |||||
bool is_dtype_ok = args.layout_src.dtype == dtype::Float32(); | |||||
bool is_mode_ok = args.opr->param().mode == Mode::MAX; | |||||
bool is_format_ok = args.opr->param().format == Param::Format::NCHW88; | |||||
bool is_shape_ok = args.opr->param().window_h >= 10 && | |||||
args.opr->param().window_h <= 15 && | |||||
args.opr->param().window_w >= 10 && | |||||
args.opr->param().window_w <= 15; | |||||
bool is_stride_ok = | |||||
args.opr->param().stride_h == 1 && args.opr->param().stride_w == 1; | |||||
//! this condition guarantee size of dst's memory is bigger enough because | |||||
//! dst's memory will be used as workspace to store intermediate result. | |||||
bool is_pad_ok = | |||||
args.opr->param().pad_h >= args.opr->param().window_h / 2 && | |||||
args.opr->param().pad_w >= args.opr->param().window_w / 2; | |||||
bool is_ins_ok = is_supported(SIMDType::AVX); | |||||
return is_dtype_ok && is_mode_ok && is_format_ok && is_shape_ok && | |||||
is_pad_ok && is_stride_ok && is_ins_ok; | |||||
} | |||||
void PoolingImpl::AlgoMaxS1NCHW88AVX::exec(const ExecArgs& args) const { | |||||
auto handle = args.handle; | |||||
size_t N = args.layout_src.shape[0]; | |||||
static size_t VECSIZE = 8; | |||||
size_t PH = args.opr->param().pad_h; | |||||
size_t PW = args.opr->param().pad_w; | |||||
size_t WH = args.opr->param().window_h; | |||||
size_t WW = args.opr->param().window_w; | |||||
size_t IC = args.layout_src.shape[1]; | |||||
size_t IH = args.layout_src.shape[2]; | |||||
size_t IW = args.layout_src.shape[3]; | |||||
size_t OH = args.layout_dst.shape[2]; | |||||
size_t OW = args.layout_dst.shape[3]; | |||||
float* src_ptr = reinterpret_cast<float*>(args.src_tensor->raw_ptr); | |||||
float* dst_ptr = reinterpret_cast<float*>(args.dst_tensor->raw_ptr); | |||||
auto run = [IC, src_ptr, dst_ptr, IH, IW, OH, OW, PH, PW, WH, WW]( | |||||
size_t index, size_t) { | |||||
size_t n = index / IC; | |||||
size_t c = index % IC; | |||||
float* src = | |||||
src_ptr + n * IH * IW * IC * VECSIZE + IH * IW * c * VECSIZE; | |||||
float* dst = | |||||
dst_ptr + n * OH * OW * IC * VECSIZE + OH * OW * c * VECSIZE; | |||||
max_pooling_s1_nchw88_avx_kern(src, dst, IH, IW, OH, OW, PH, PW, WH, | |||||
WW); | |||||
}; | |||||
MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN(handle, N * IC, run); | |||||
} |
@@ -29,6 +29,7 @@ public: | |||||
X86_MeanW2S2SSE3, | X86_MeanW2S2SSE3, | ||||
X86_MaxW2S2SSE, | X86_MaxW2S2SSE, | ||||
X86_MaxW3S3SSE, | X86_MaxW3S3SSE, | ||||
X86_MaxS1NCHW88AVX, | |||||
#if MEGDNN_X86_WITH_MKL_DNN | #if MEGDNN_X86_WITH_MKL_DNN | ||||
X86_MKLDNNNCHW, | X86_MKLDNNNCHW, | ||||
X86_MKLDNNNCHW88, | X86_MKLDNNNCHW88, | ||||
@@ -87,11 +88,11 @@ ALGO_IMPL(MeanW2S2AVX) | |||||
ALGO_IMPL(MeanW2S2SSE3) | ALGO_IMPL(MeanW2S2SSE3) | ||||
ALGO_IMPL(MaxW2S2SSE) | ALGO_IMPL(MaxW2S2SSE) | ||||
ALGO_IMPL(MaxW3S3SSE) | ALGO_IMPL(MaxW3S3SSE) | ||||
ALGO_IMPL(MaxS1NCHW88AVX) | |||||
#if MEGDNN_X86_WITH_MKL_DNN | #if MEGDNN_X86_WITH_MKL_DNN | ||||
ALGO_IMPL(MKLDNNNCHW) | ALGO_IMPL(MKLDNNNCHW) | ||||
ALGO_IMPL(MKLDNNNCHW88) | ALGO_IMPL(MKLDNNNCHW88) | ||||
#endif | #endif | ||||
#undef ALGO_IMPL | #undef ALGO_IMPL | ||||
class PoolingImpl::AlgoFallback final : public AlgoBase { | class PoolingImpl::AlgoFallback final : public AlgoBase { | ||||
@@ -118,6 +119,7 @@ private: | |||||
AlgoMKLDNNNCHW algo_mkldnn_nchw; | AlgoMKLDNNNCHW algo_mkldnn_nchw; | ||||
AlgoMKLDNNNCHW88 algo_mkldnn_nchw88; | AlgoMKLDNNNCHW88 algo_mkldnn_nchw88; | ||||
#endif | #endif | ||||
AlgoMaxS1NCHW88AVX algo_max_w13s1_nchw88_avx; | |||||
AlgoFallback algo_fallback; | AlgoFallback algo_fallback; | ||||
public: | public: | ||||
@@ -21,6 +21,7 @@ private: | |||||
class AlgoMeanW2S2SSE3; | class AlgoMeanW2S2SSE3; | ||||
class AlgoMaxW2S2SSE; | class AlgoMaxW2S2SSE; | ||||
class AlgoMaxW3S3SSE; | class AlgoMaxW3S3SSE; | ||||
class AlgoMaxS1NCHW88AVX; | |||||
#if MEGDNN_X86_WITH_MKL_DNN | #if MEGDNN_X86_WITH_MKL_DNN | ||||
class AlgoMKLDNNNCHW; | class AlgoMKLDNNNCHW; | ||||
class AlgoMKLDNNNCHW88; | class AlgoMKLDNNNCHW88; | ||||
@@ -24,6 +24,70 @@ TEST_F(X86, POOLING) { | |||||
} | } | ||||
} | } | ||||
//! Correctness check of stride-1 NCHW88 max pooling with large windows on the
//! single-thread x86 handle.
TEST_F(X86, S1POOLING88) {
    Checker<Pooling> checker(handle());
    auto run = [&](size_t WH, size_t WW, size_t PH, size_t PW, size_t SH,
                   size_t SW, size_t N, size_t C, size_t H, size_t W) {
        Pooling::Param param;
        param.format = param::Pooling::Format::NCHW88;
        param.window_h = WH;
        param.window_w = WW;
        param.pad_h = PH;
        param.pad_w = PW;
        param.stride_w = SW;
        param.stride_h = SH;
        param.mode = param::Pooling::Mode::MAX;
        checker.set_param(param);
        //! dst shape is left empty so that it gets deduced
        checker.execs({{N, C, H, W, 8}, {}});
    };

    //! window sizes 10..15 inclusive, the range accepted by
    //! AlgoMaxS1NCHW88AVX::is_available
    for (size_t wh = 10; wh <= 15; ++wh) {
        for (size_t ww = 10; ww <= 15; ++ww) {
            for (size_t n : {1, 2, 4}) {
                for (size_t c : {1, 4}) {
                    for (size_t h : {10, 13, 20}) {
                        for (size_t w : {10, 13, 20}) {
                            run(wh, ww, wh / 2, ww / 2, 1, 1, n, c, h, w);
                        }
                    }
                }
            }
        }
    }
}
//! Correctness check of stride-1 NCHW88 max pooling with large windows on the
//! multi-thread x86 handle.
TEST_F(X86_MULTI_THREADS, S1POOLING88) {
    Checker<Pooling> checker(handle());
    auto run = [&](size_t WH, size_t WW, size_t PH, size_t PW, size_t SH,
                   size_t SW, size_t N, size_t C, size_t H, size_t W) {
        Pooling::Param param;
        param.format = param::Pooling::Format::NCHW88;
        param.window_h = WH;
        param.window_w = WW;
        param.pad_h = PH;
        param.pad_w = PW;
        param.stride_w = SW;
        param.stride_h = SH;
        param.mode = param::Pooling::Mode::MAX;
        checker.set_param(param);
        //! dst shape is left empty so that it gets deduced
        checker.execs({{N, C, H, W, 8}, {}});
    };

    //! window sizes 10..15 inclusive, the range accepted by
    //! AlgoMaxS1NCHW88AVX::is_available
    for (size_t wh = 10; wh <= 15; ++wh) {
        for (size_t ww = 10; ww <= 15; ++ww) {
            for (size_t n : {1, 2, 4}) {
                for (size_t c : {1, 4}) {
                    for (size_t h : {10, 13, 20}) {
                        for (size_t w : {10, 13, 20}) {
                            run(wh, ww, wh / 2, ww / 2, 1, 1, n, c, h, w);
                        }
                    }
                }
            }
        }
    }
}
#if MEGDNN_X86_WITH_MKL_DNN | #if MEGDNN_X86_WITH_MKL_DNN | ||||
TEST_F(X86, POOLING88) { | TEST_F(X86, POOLING88) { | ||||
Checker<Pooling> checker(handle()); | Checker<Pooling> checker(handle()); | ||||
@@ -104,6 +168,42 @@ TEST_F(X86, BENCHMARK_POOLING) { | |||||
TEST_F(X86_MULTI_THREADS, BENCHMARK_POOLING) { | TEST_F(X86_MULTI_THREADS, BENCHMARK_POOLING) { | ||||
test_x86_megdnn_pooling(handle()); | test_x86_megdnn_pooling(handle()); | ||||
} | } | ||||
//! Benchmark of stride-1 NCHW88 max pooling: prints the average time per run
//! and the resulting throughput for a 13x13 window on a 20x20 input.
TEST_F(X86, BENCHMARK_POOLING_MAX_S1_NCHW88) {
    constexpr size_t RUNS = 50;
    auto x86_handle = handle();
    Benchmarker<Pooling> benchmarker_pooling(x86_handle);
    benchmarker_pooling.set_times(RUNS);

    auto run = [&](uint32_t pad, uint32_t stride, uint32_t window_size,
                   size_t in_number, size_t in_channel, size_t in_height,
                   size_t in_width) {
        //! a throw-away operator is used to build the param and to deduce
        //! the output layout
        auto opr = x86_handle->create_operator<Pooling>();
        opr->param() = {param::Pooling::Mode::MAX,
                        pad,
                        pad,
                        stride,
                        stride,
                        window_size,
                        window_size};
        opr->param().format = param::Pooling::Format::NCHW88;

        //! in_channel counts scalar channels; NCHW88 stores them in packs of 8
        TensorShape src_shape{in_number, in_channel / 8, in_height, in_width,
                              8};
        TensorLayout dst_layout;
        opr->deduce_layout({src_shape, dtype::Float32()}, dst_layout);

        //! one operation per window element per output element, scaled by 1e-9
        float computation =
                dst_layout.total_nr_elems() * window_size * window_size * 1e-9;

        //! total time of all runs divided by the run count
        auto pooling_used = benchmarker_pooling.set_param(opr->param())
                                    .exec(TensorShapeArray{src_shape, {}}) /
                            RUNS;
        float through_put = computation / pooling_used * 1e3;

        printf("profiling max pooling NCHW88 {%zu,%zu,%zu,%zu,8}\nuse time : "
               "%f ms\nthrough_put : %f Gflops\n",
               in_number, in_channel / 8, in_height, in_width, pooling_used,
               through_put);
    };

    run(6, 1, 13, 1, 32 * 8, 20, 20);
}
#endif | #endif | ||||
#if MEGDNN_X86_WITH_MKL_DNN | #if MEGDNN_X86_WITH_MKL_DNN | ||||
TEST_F(X86, POOLING_INT8) { | TEST_F(X86, POOLING_INT8) { | ||||