Browse Source

feat(dnn): add conv_bwd_data and conv_bwd_filter accuracy shake check

GitOrigin-RevId: 4069e083d2
release-1.4
Megvii Engine Team 4 years ago
parent
commit
747b53c018
3 changed files with 44 additions and 2 deletions
  1. +7
    -1
      dnn/src/cuda/convolution/backward_data/algo.h
  2. +2
    -1
      dnn/test/common/accuracy_shake_checker.h
  3. +35
    -0
      dnn/test/cuda/accuracy_shake.cpp

+ 7
- 1
dnn/src/cuda/convolution/backward_data/algo.h View File

@@ -236,7 +236,13 @@ public:
TensorLayout& grad_pg);
MEGDNN_DECL_ALGO_TYPE(CUDA_GROUP_CONV_GENERAL)
AlgoAttribute attribute() const override {
auto ret = static_cast<AlgoAttribute>(0);
auto ret = AlgoAttribute::DEFAULT;
#define cb(attr) \
if (m_impl->contain_attribute_all(attr)) { \
ret |= attr; \
}
MEGDNN_FOREACH_ALGO_ATTRIBUTE_INHERITABLE(cb)
#undef cb
if (m_impl->contain_attribute_all(AlgoAttribute::REPRODUCIBLE)) {
ret |= AlgoAttribute::REPRODUCIBLE;
}


+ 2
- 1
dnn/test/common/accuracy_shake_checker.h View File

@@ -168,7 +168,8 @@ public:
AlgoProxy<Opr, OprTrait<Opr>::arity>::get_all_algorithms_info(
opr, layouts)) {
if (!(algo_info.attribute &
AlgoAttribute::ACCURACY_DEPEND_ON_BATCH) &&
AlgoAttribute::ACCURACY_DEPEND_ON_BATCH) &&
(algo_info.attribute & AlgoAttribute::REPRODUCIBLE) &&
std::regex_match(
algo_info.desc.name,
std::regex("(.*)(" + m_policy_name.name + ")(.*)"))) {


+ 35
- 0
dnn/test/cuda/accuracy_shake.cpp View File

@@ -241,6 +241,41 @@ TEST_F(CUDA, SHAKE_LOCAL_SHARE) {
checker.exec({{20, 16, 32, 32}, {3, 3, 16, 3, 3, 64}, {}});
}

TEST_F(CUDA, SHAKE_CONVOLUTION_BACKWARD_DATA) {
AccuracyShakeChecker<ConvolutionBackwardData> checker(handle_cuda());
NormalRNG default_rng;
checker.set_dtype(0, dtype::Float32())
.set_dtype(1, dtype::Float32())
.set_rng(0, &default_rng)
.set_rng(1, &default_rng);
// ConvolutionBackwardData
checker.exec({{8, 16, 3, 3}, {64, 8, 5, 5}, {64, 16, 7, 7}});

// group
ConvolutionBackwardData::Param param;
param.sparse = Convolution::Param::Sparse::GROUP;
checker.set_param(param);
checker.exec({{2, 16, 32, 3, 3}, {2, 32, 5, 5}, {2, 64, 7, 7}});
checker.exec({{2, 8, 32, 3, 3}, {64, 16, 19, 19}, {64, 64, 21, 21}});
}

TEST_F(CUDA, SHAKE_CONVOLUTION_BACKWARD_FILTER) {
AccuracyShakeChecker<ConvolutionBackwardFilter> checker(handle_cuda());
NormalRNG default_rng;
checker.set_dtype(0, dtype::Float32())
.set_dtype(1, dtype::Float32())
.set_rng(0, &default_rng)
.set_rng(1, &default_rng);
// ConvolutionBackwardFilter
checker.exec({{2, 64, 7, 7}, {2, 32, 5, 5}, {32, 64, 3, 3}});

// group
ConvolutionBackwardFilter::Param param;
param.sparse = Convolution::Param::Sparse::GROUP;
checker.set_param(param);
checker.exec({{2, 64, 7, 7}, {2, 32, 5, 5}, {2, 16, 32, 3, 3}});
}

} // namespace test
} // namespace megdnn



Loading…
Cancel
Save