From 3d3666b6e0f7ff76f089297a27917725b441e5fd Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 23 Aug 2021 16:51:22 +0800 Subject: [PATCH] test(dnn/bn): add compatible configs for NHWC BN GitOrigin-RevId: ac757ca307f53ee2af9af8e3943a1d7776fa6c37 --- dnn/test/common/bn.h | 14 ++++++++- dnn/test/common/checker.cpp | 2 +- dnn/test/common/checker.h | 5 ++++ dnn/test/common/deduce_layout_proxy.h | 9 ++++++ dnn/test/common/exec_proxy.h | 17 +++++++++++ dnn/test/common/rng.cpp | 4 +++ dnn/test/cuda/bn.cpp | 53 ++++++++++++++++++++++------------- dnn/test/rocm/bn.cpp | 11 +++++--- 8 files changed, 89 insertions(+), 26 deletions(-) diff --git a/dnn/test/common/bn.h b/dnn/test/common/bn.h index 5c67ccc2..e31c0119 100644 --- a/dnn/test/common/bn.h +++ b/dnn/test/common/bn.h @@ -53,6 +53,18 @@ std::vector get_args() { TensorShape{1, 3, 1, 1}, dtype::Float16()); } + // case 3: 1 x 1 x 1 x C + + for (size_t i = 4; i < 257; i *= 4) { + param::BN param; + param.fwd_mode = param::BN::FwdMode::TRAINING; + param.param_dim = param::BN::ParamDim::DIM_111C; + args.emplace_back(param, TensorShape{3, i, i, 3}, + TensorShape{1, 1, 1, 3}, dtype::Float32()); + args.emplace_back(param, TensorShape{3, i, i, 3}, + TensorShape{1, 1, 1, 3}, dtype::Float16()); + } + return args; } @@ -60,4 +72,4 @@ std::vector get_args() { } // namespace test } // namespace megdnn -// vim: syntax=cpp.doxygen \ No newline at end of file +// vim: syntax=cpp.doxygen diff --git a/dnn/test/common/checker.cpp b/dnn/test/common/checker.cpp index 26ee4fb6..50e473f4 100644 --- a/dnn/test/common/checker.cpp +++ b/dnn/test/common/checker.cpp @@ -419,7 +419,7 @@ void CheckerHelper::copy_tensors_from_device(const TensorValueArray& dest, void CheckerHelper::check_tensors(const TensorValueArray& expected, const TensorValueArray& computed) { for (size_t i = 0; i < expected.size(); ++i) { - if (expected[i].layout.ndim == 0) + if (expected[i].layout.ndim == 0 || m_bypass.find(i) != m_bypass.end()) continue; if (m_allow_invalid_check) { MEGDNN_ASSERT_TENSOR_EQ_EPS_AVG_ALLOW_INVALID( diff --git a/dnn/test/common/checker.h b/dnn/test/common/checker.h index f7413135..61a60aea 100644 --- a/dnn/test/common/checker.h +++ b/dnn/test/common/checker.h @@ -69,6 +69,7 @@ protected: std::unordered_map m_rng; std::unordered_map m_dtype; std::unordered_map m_fmt; + std::set m_bypass; float_t m_epsilon = 1e-3, m_max_avg_error = 1e-3, m_max_avg_biased_error = 1e-3; float_t m_perf_check_threshold = -1; @@ -184,6 +185,10 @@ public: m_rng[idx] = rng; return *this; } + Checker& set_bypass(size_t idx) { + m_bypass.insert(idx); + return *this; + } //! max error of a single element Checker& set_epsilon(dt_float32 epsilon) { m_epsilon = epsilon; diff --git a/dnn/test/common/deduce_layout_proxy.h b/dnn/test/common/deduce_layout_proxy.h index 39b453f4..05794d60 100644 --- a/dnn/test/common/deduce_layout_proxy.h +++ b/dnn/test/common/deduce_layout_proxy.h @@ -82,6 +82,15 @@ struct DeduceLayoutProxy { } }; +template +struct DeduceLayoutProxy { + static void deduce_layout(Opr* opr, TensorLayoutArray& layouts) { + megdnn_assert(layouts.size() == 9); + opr->deduce_layout(layouts[0], layouts[1], layouts[2], layouts[3], + layouts[4], layouts[5], layouts[6], layouts[7], + layouts[8]); + } +}; } // namespace test } // namespace megdnn diff --git a/dnn/test/common/exec_proxy.h b/dnn/test/common/exec_proxy.h index 1393f357..a4091a50 100644 --- a/dnn/test/common/exec_proxy.h +++ b/dnn/test/common/exec_proxy.h @@ -23,6 +23,23 @@ template struct ExecProxy; template +struct ExecProxy { + WorkspaceWrapper W; + void exec(Opr* opr, const TensorNDArray& tensors) { + if (!W.valid()) { + W = WorkspaceWrapper(opr->handle(), 0); + } + W.update(opr->get_workspace_in_bytes( + tensors[0].layout, tensors[1].layout, tensors[2].layout, + tensors[3].layout, tensors[4].layout, tensors[5].layout, + tensors[6].layout, tensors[7].layout, tensors[8].layout)); + opr->exec(tensors[0], tensors[1], tensors[2], tensors[3], tensors[4], + tensors[5], tensors[6], tensors[7], tensors[8], + W.workspace()); + } +}; + +template struct ExecProxy { WorkspaceWrapper W; void exec(Opr* opr, const TensorNDArray& tensors) { diff --git a/dnn/test/common/rng.cpp b/dnn/test/common/rng.cpp index 70921a87..57004d02 100644 --- a/dnn/test/common/rng.cpp +++ b/dnn/test/common/rng.cpp @@ -211,6 +211,10 @@ void IIDRNG::gen(const TensorND& tensor) { } return; } + if (tensor.layout.dtype.enumv() == DTypeEnum::Byte) { + memset(tensor.raw_ptr, 0, tensor.layout.access_bytes()); + return; + } megdnn_assert(0, "IIDRNG does not know how to generate value for DType %s", tensor.layout.dtype.name()); } diff --git a/dnn/test/cuda/bn.cpp b/dnn/test/cuda/bn.cpp index d5202ef8..1aace733 100644 --- a/dnn/test/cuda/bn.cpp +++ b/dnn/test/cuda/bn.cpp @@ -6,10 +6,13 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. */ #include "test/cuda/fixture.h" +#include "src/cuda/batch_normalization/opr_impl.h" +#include "src/cuda/utils.h" #include "megdnn/opr_param_defs.h" #include "megdnn/oprs.h" #include "test/common/bn.h" @@ -21,15 +24,26 @@ namespace megdnn { namespace test { -TEST_F(CUDA, BN_FORWARD) { +TEST_F(CUDA, BN_FORWARD_BACKWARD) { using namespace batch_normalization; + using cuda::cudnn_handle; + using cuda::batch_normalization::BNTensorDescHolder; + using cuda::batch_normalization::get_reserve_size; std::vector args = get_args(); Checker checker(handle_cuda()); + Checker checker_bwd(handle_cuda()); for (auto&& arg : args) { - for (int i = 0; i < 8; ++i) { + auto tensor_desc = BNTensorDescHolder({arg.src, arg.dtype}, arg.param.param_dim, + arg.param.fwd_mode); + auto reserve = get_reserve_size(cudnn_handle(handle_cuda()), tensor_desc); + // Forward + for (int i = 0; i < 9; ++i) { checker.set_dtype(i, dtype::Float32()); } checker.set_dtype(0, arg.dtype); + checker.set_dtype(7, dtype::Byte()); + checker.set_dtype(8, arg.dtype); + checker.set_bypass(7); checker.set_epsilon(1e-3).set_param(arg.param); for (bool need_statistic : {false, true}) checker.exec({ @@ -40,27 +54,26 @@ TEST_F(CUDA, BN_FORWARD) { : TensorShape({0}), // mean need_statistic ? arg.param_shape : TensorShape({0}), // variance - arg.param_shape, // batch_mean - arg.param_shape, // batch_inv_variance - {} // dst + arg.param_shape, // batch_mean + arg.param_shape, // batch_inv_variance + {reserve}, // reserve + arg.src // dst }); - } -} -TEST_F(CUDA, BN_BACKWARD) { - using namespace batch_normalization; - std::vector args = get_args(); - Checker checker(handle_cuda()); - for (auto&& arg : args) { - for (int i = 0; i < 8; ++i) { - checker.set_dtype(i, dtype::Float32()); + // Backward + for (int i = 0; i < 9; ++i) { + checker_bwd.set_dtype(i, dtype::Float32()); } - checker.set_dtype(0, arg.dtype) // x - .set_dtype(1, arg.dtype) // dy - .set_dtype(7, arg.dtype); // dx - checker.set_epsilon(1e-3).set_param(arg.param).exec( + checker_bwd + .set_dtype(0, arg.dtype) // x + .set_dtype(1, arg.dtype) // dy + .set_dtype(5, dtype::Byte()) // reserve + .set_dtype(8, arg.dtype) // dx + .set_bypass(5); + checker_bwd.set_epsilon(1e-3).set_param(arg.param).exec( {arg.src, arg.src, arg.param_shape, arg.param_shape, - arg.param_shape, arg.param_shape, arg.param_shape, arg.src}); + arg.param_shape, {reserve}, arg.param_shape, arg.param_shape, + arg.src}); } } diff --git a/dnn/test/rocm/bn.cpp b/dnn/test/rocm/bn.cpp index 59f8091c..7fa0e593 100644 --- a/dnn/test/rocm/bn.cpp +++ b/dnn/test/rocm/bn.cpp @@ -31,6 +31,7 @@ TEST_F(ROCM, BN_FORWARD) { checker.set_dtype(i, dtype::Float32()); } checker.set_dtype(0, arg.dtype); + checker.set_dtype(8, arg.dtype); checker.set_epsilon(1e-3).set_param(arg.param); for (bool need_statistic : {false, true}) checker.exec({ @@ -43,7 +44,8 @@ TEST_F(ROCM, BN_FORWARD) { : TensorShape({0}), // variance arg.param_shape, // batch_mean arg.param_shape, // batch_inv_variance - {} // dst + {0}, // reserve + arg.src // dst }); } } @@ -53,15 +55,16 @@ TEST_F(ROCM, BN_BACKWARD) { std::vector args = get_args(); Checker checker(handle_rocm()); for (auto&& arg : args) { - for (int i = 0; i < 8; ++i) { + for (int i = 0; i < 9; ++i) { checker.set_dtype(i, dtype::Float32()); } checker.set_dtype(0, arg.dtype) // x .set_dtype(1, arg.dtype) // dy - .set_dtype(7, arg.dtype); // dx + .set_dtype(8, arg.dtype); // dx checker.set_epsilon(1e-3).set_param(arg.param).exec( {arg.src, arg.src, arg.param_shape, arg.param_shape, - arg.param_shape, arg.param_shape, arg.param_shape, arg.src}); + arg.param_shape, {0}, arg.param_shape, arg.param_shape, + arg.src}); } }