Browse Source

feat(megdnn/rocm): add average inclusive mode for pooling

GitOrigin-RevId: 08162bd1fb
release-1.1
Megvii Engine Team 4 years ago
parent
commit
aea829c9fa
2 changed files with 5 additions and 10 deletions
  1. +4
    -0
      dnn/src/rocm/miopen_wrapper.cpp
  2. +1
    -10
      dnn/test/rocm/pooling.cpp

+ 4
- 0
dnn/src/rocm/miopen_wrapper.cpp View File

@@ -9,6 +9,7 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "hcc_detail/hcc_defs_prologue.h"
#include "megdnn/opr_param_defs.h"
#include "src/rocm/miopen_wrapper.h"

#include "src/common/utils.h"
@@ -114,6 +115,9 @@ void PoolingDesc::set(const param::Pooling& param) {
case param::Pooling::Mode::AVERAGE_COUNT_EXCLUDE_PADDING:
mode = miopenPoolingAverage;
break;
case param::Pooling::Mode::AVERAGE:
mode = miopenPoolingAverageInclusive;
break;
default:
megdnn_throw(megdnn_mangle("Unsupported pooling mode for miopen"));
}


+ 1
- 10
dnn/test/rocm/pooling.cpp View File

@@ -31,10 +31,6 @@ TEST_F(ROCM, POOLING_FORWARD) {
for (auto format : {Format::NCHW})
for (auto&& arg : args) {
auto param = arg.param;
if (param.mode == param::Pooling::Mode::AVERAGE) {
param.mode =
param::Pooling::Mode::AVERAGE_COUNT_EXCLUDE_PADDING;
}
auto src = arg.ishape;
param.format = format;
Checker<Pooling> checker(handle_rocm());
@@ -53,11 +49,6 @@ TEST_F(ROCM, POOLING_BACKWARD) {
TensorLayout ilayout = TensorLayout(arg.ishape, dtype::Float32());
TensorLayout olayout;

auto& param = arg.param;
if (param.mode == param::Pooling::Mode::AVERAGE) {
param.mode = param::Pooling::Mode::AVERAGE_COUNT_EXCLUDE_PADDING;
}

auto constraint = [this,
arg](CheckerHelper::TensorValueArray& tensors_orig) {
megdnn_assert(tensors_orig.size() == 4);
@@ -141,7 +132,7 @@ TEST_F(ROCM, POOLING_FWD_BENCHMARK) {
benchmarker.set_param(param);
size_t OH = infer_conv_shape(IH, FH, SH, PH);
size_t OW = infer_conv_shape(IW, FW, SW, PW);
// warm up
// warm up
benchmarker.execs({{N, IC, IH, IW}, {N, IC, OH, OW}});
// do actual benchmark
auto time_ms = benchmarker.execs({{N, IC, IH, IW}, {N, IC, OH, OW}});


Loading…
Cancel
Save