Browse Source

fix(mgb): fix rocm pooling

GitOrigin-RevId: 44876d398e
release-1.6
Megvii Engine Team 3 years ago
parent
commit
3abe0b2462
1 changed files with 5 additions and 5 deletions
  1. +5
    -5
      dnn/src/rocm/pooling/algo.cpp

+ 5
- 5
dnn/src/rocm/pooling/algo.cpp View File

@@ -60,10 +60,10 @@ void PoolingForwardImpl::AlgoMIOpen::init_mode(
case param::Pooling::Mode::MAX:
mode = miopenPoolingMax;
break;
case param::Pooling::Mode::AVERAGE:
case param::Pooling::Mode::AVERAGE_COUNT_EXCLUDE_PADDING:
mode = miopenPoolingAverage;
break;
case param::Pooling::Mode::AVERAGE_COUNT_EXCLUDE_PADDING:
case param::Pooling::Mode::AVERAGE:
mode = miopenPoolingAverageInclusive;
break;
default:
@@ -96,7 +96,7 @@ void PoolingForwardImpl::AlgoMIOpen::exec(const ExecArgs& args) const {
miopen_check(miopenPoolingForward(
handle, miopen_desc, &alpha, src_desc.desc,
args.src_tensor->raw_ptr, &beta, dst_desc.desc,
args.src_tensor->raw_ptr, false, nullptr, 0_z));
args.dst_tensor->raw_ptr, false, nullptr, 0_z));
miopen_check(miopenDestroyPoolingDescriptor(miopen_desc));
}

@@ -163,10 +163,10 @@ void PoolingBackwardImpl::AlgoMIOpen::init_mode(const ExecArgs& args,
case param::Pooling::Mode::MAX:
mode = miopenPoolingMax;
break;
case param::Pooling::Mode::AVERAGE:
case param::Pooling::Mode::AVERAGE_COUNT_EXCLUDE_PADDING:
mode = miopenPoolingAverage;
break;
case param::Pooling::Mode::AVERAGE_COUNT_EXCLUDE_PADDING:
case param::Pooling::Mode::AVERAGE:
mode = miopenPoolingAverageInclusive;
break;
default:


Loading…
Cancel
Save