
test(dnn/bn): add compatible configs for NHWC BN

GitOrigin-RevId: ac757ca307
release-1.7
Megvii Engine Team 3 years ago
commit 3d3666b6e0
8 changed files with 89 additions and 26 deletions
  1. dnn/test/common/bn.h (+13, -1)
  2. dnn/test/common/checker.cpp (+1, -1)
  3. dnn/test/common/checker.h (+5, -0)
  4. dnn/test/common/deduce_layout_proxy.h (+9, -0)
  5. dnn/test/common/exec_proxy.h (+17, -0)
  6. dnn/test/common/rng.cpp (+4, -0)
  7. dnn/test/cuda/bn.cpp (+33, -20)
  8. dnn/test/rocm/bn.cpp (+7, -4)

dnn/test/common/bn.h (+13, -1)

@@ -53,6 +53,18 @@ std::vector<TestArg> get_args() {
                           TensorShape{1, 3, 1, 1}, dtype::Float16());
     }
 
+    // case 3: 1 x 1 x 1 x C
+    for (size_t i = 4; i < 257; i *= 4) {
+        param::BN param;
+        param.fwd_mode = param::BN::FwdMode::TRAINING;
+        param.param_dim = param::BN::ParamDim::DIM_111C;
+        args.emplace_back(param, TensorShape{3, i, i, 3},
+                          TensorShape{1, 1, 1, 3}, dtype::Float32());
+        args.emplace_back(param, TensorShape{3, i, i, 3},
+                          TensorShape{1, 1, 1, 3}, dtype::Float16());
+    }
+
     return args;
 }


@@ -60,4 +72,4 @@ std::vector<TestArg> get_args() {
 } // namespace test
 } // namespace megdnn
 
-// vim: syntax=cpp.doxygen
+// vim: syntax=cpp.doxygen
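
Note on the hunk above: the new case 3 exercises NHWC inputs at spatial sizes 4, 16, 64 and 256 with per-channel (1 x 1 x 1 x C) parameters, each in Float32 and Float16. A minimal standalone sketch (plain C++, not MegDNN test code) that enumerates the shape pairs the loop generates:

// Illustrative sketch only: prints the NHWC shape pairs produced by the
// "case 3" loop above (i takes the values 4, 16, 64, 256).
#include <cstdio>

int main() {
    for (unsigned i = 4; i < 257; i *= 4) {
        // data tensor: N x H x W x C = 3 x i x i x 3
        // per-channel parameters: 1 x 1 x 1 x C = 1 x 1 x 1 x 3
        std::printf("src = {3, %u, %u, 3}, param = {1, 1, 1, 3}\n", i, i);
    }
    return 0;
}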

dnn/test/common/checker.cpp (+1, -1)

@@ -419,7 +419,7 @@ void CheckerHelper::copy_tensors_from_device(const TensorValueArray& dest,
 void CheckerHelper::check_tensors(const TensorValueArray& expected,
                                   const TensorValueArray& computed) {
     for (size_t i = 0; i < expected.size(); ++i) {
-        if (expected[i].layout.ndim == 0)
+        if (expected[i].layout.ndim == 0 || m_bypass.find(i) != m_bypass.end())
             continue;
         if (m_allow_invalid_check) {
             MEGDNN_ASSERT_TENSOR_EQ_EPS_AVG_ALLOW_INVALID(


dnn/test/common/checker.h (+5, -0)

@@ -69,6 +69,7 @@ protected:
     std::unordered_map<size_t, RNG*> m_rng;
     std::unordered_map<size_t, DType> m_dtype;
     std::unordered_map<size_t, TensorFormat> m_fmt;
+    std::set<size_t> m_bypass;
     float_t m_epsilon = 1e-3, m_max_avg_error = 1e-3,
             m_max_avg_biased_error = 1e-3;
     float_t m_perf_check_threshold = -1;
@@ -184,6 +185,10 @@ public:
         m_rng[idx] = rng;
         return *this;
     }
+    Checker& set_bypass(size_t idx) {
+        m_bypass.insert(idx);
+        return *this;
+    }
     //! max error of a single element
     Checker& set_epsilon(dt_float32 epsilon) {
         m_epsilon = epsilon;
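
Note: set_bypass lets a test mark an output slot that should not be compared against the reference result; the CUDA BN test below uses it for the cuDNN reserve buffer, whose byte contents are opaque. A self-contained sketch of the pattern (plain C++ with hypothetical names, not the actual Checker API):

// Illustrative sketch only: output slots whose index is in the bypass set are
// skipped when comparing expected vs. computed results.
#include <cassert>
#include <cstddef>
#include <set>
#include <vector>

bool results_match(const std::vector<int>& expected, const std::vector<int>& computed,
                   const std::set<size_t>& bypass) {
    assert(expected.size() == computed.size());
    for (size_t i = 0; i < expected.size(); ++i) {
        if (bypass.find(i) != bypass.end())
            continue;  // e.g. an opaque reserve buffer: contents are not comparable
        if (expected[i] != computed[i])
            return false;
    }
    return true;
}

int main() {
    std::set<size_t> bypass{1};
    // slot 1 differs but is bypassed, so the check still passes
    return results_match({1, 2, 3}, {1, 99, 3}, bypass) ? 0 : 1;
}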


dnn/test/common/deduce_layout_proxy.h (+9, -0)

@@ -82,6 +82,15 @@ struct DeduceLayoutProxy<Opr, 8, true> {
     }
 };
 
+template <typename Opr>
+struct DeduceLayoutProxy<Opr, 9, true> {
+    static void deduce_layout(Opr* opr, TensorLayoutArray& layouts) {
+        megdnn_assert(layouts.size() == 9);
+        opr->deduce_layout(layouts[0], layouts[1], layouts[2], layouts[3],
+                           layouts[4], layouts[5], layouts[6], layouts[7],
+                           layouts[8]);
+    }
+};
 } // namespace test
 } // namespace megdnn
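
Note: the 9-argument specialization corresponds to the BN forward signature once the reserve output is counted. The index order below is inferred from the shapes and dtype indices used in the CUDA test later in this diff; it is a documentation sketch, not a MegDNN header.

// Assumed index order for the 9-tensor BN forward call; indices 1 and 2 are
// the per-channel scale and bias (both use arg.param_shape in the tests).
enum BnForwardTensorIndex {
    BN_SRC = 0,
    BN_SCALE = 1,
    BN_BIAS = 2,
    BN_MEAN = 3,
    BN_VARIANCE = 4,
    BN_BATCH_MEAN = 5,
    BN_BATCH_INV_VARIANCE = 6,
    BN_RESERVE = 7,  // Byte-typed, bypassed during result checking
    BN_DST = 8
};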




dnn/test/common/exec_proxy.h (+17, -0)

@@ -23,6 +23,23 @@ template <typename Opr, size_t Arity, bool has_workspace>
 struct ExecProxy;
 
 template <typename Opr>
+struct ExecProxy<Opr, 9, true> {
+    WorkspaceWrapper W;
+    void exec(Opr* opr, const TensorNDArray& tensors) {
+        if (!W.valid()) {
+            W = WorkspaceWrapper(opr->handle(), 0);
+        }
+        W.update(opr->get_workspace_in_bytes(
+                tensors[0].layout, tensors[1].layout, tensors[2].layout,
+                tensors[3].layout, tensors[4].layout, tensors[5].layout,
+                tensors[6].layout, tensors[7].layout, tensors[8].layout));
+        opr->exec(tensors[0], tensors[1], tensors[2], tensors[3], tensors[4],
+                  tensors[5], tensors[6], tensors[7], tensors[8],
+                  W.workspace());
+    }
+};
+
+template <typename Opr>
 struct ExecProxy<Opr, 8, true> {
     WorkspaceWrapper W;
     void exec(Opr* opr, const TensorNDArray& tensors) {
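
Note: the new 9-tensor ExecProxy keeps the pattern of the existing specializations: the workspace is created lazily on first use and updated to whatever size the operator reports before each exec. A rough standalone sketch of that pattern (std::vector standing in for WorkspaceWrapper):

// Illustrative sketch only: lazily created scratch storage that is resized to
// the size requested before every call, mirroring the proxy's W.update(...).
#include <cstddef>
#include <vector>

struct ScratchBuffer {
    std::vector<unsigned char> storage;

    unsigned char* update(size_t bytes) {
        if (storage.size() < bytes)
            storage.resize(bytes);  // grow only in this sketch; reuse otherwise
        return storage.data();
    }
};

int main() {
    ScratchBuffer ws;
    ws.update(256);   // first call: allocate
    ws.update(1024);  // larger request: grow
    ws.update(128);   // smaller request: reuse existing storage
    return 0;
}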


dnn/test/common/rng.cpp (+4, -0)

@@ -211,6 +211,10 @@ void IIDRNG::gen(const TensorND& tensor) {
         }
         return;
     }
+    if (tensor.layout.dtype.enumv() == DTypeEnum::Byte) {
+        memset(tensor.raw_ptr, 0, tensor.layout.access_bytes());
+        return;
+    }
     megdnn_assert(0, "IIDRNG does not know how to generate value for DType %s",
                   tensor.layout.dtype.name());
 }
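
Note: this branch is needed because the checker's RNG must be able to fill every input tensor, and the new Byte-typed reserve slot has no meaningful random distribution, so it is simply zero-filled. A trivial standalone equivalent:

// Illustrative sketch only: zero-fill an opaque byte buffer (stand-in for a
// Byte-typed reserve tensor) so its initial contents are deterministic.
#include <cstring>
#include <vector>

int main() {
    std::vector<unsigned char> reserve(64);
    std::memset(reserve.data(), 0, reserve.size());
    return 0;
}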


dnn/test/cuda/bn.cpp (+33, -20)

@@ -6,10 +6,13 @@
  *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
  */
 #include "test/cuda/fixture.h"
 
+#include "src/cuda/batch_normalization/opr_impl.h"
+#include "src/cuda/utils.h"
 #include "megdnn/opr_param_defs.h"
 #include "megdnn/oprs.h"
 #include "test/common/bn.h"
@@ -21,15 +24,26 @@
 namespace megdnn {
 namespace test {
 
-TEST_F(CUDA, BN_FORWARD) {
+TEST_F(CUDA, BN_FORWARD_BACKWARD) {
     using namespace batch_normalization;
+    using cuda::cudnn_handle;
+    using cuda::batch_normalization::BNTensorDescHolder;
+    using cuda::batch_normalization::get_reserve_size;
     std::vector<TestArg> args = get_args();
     Checker<BNForward> checker(handle_cuda());
+    Checker<BNBackward> checker_bwd(handle_cuda());
     for (auto&& arg : args) {
-        for (int i = 0; i < 8; ++i) {
+        auto tensor_desc = BNTensorDescHolder({arg.src, arg.dtype}, arg.param.param_dim,
+                                              arg.param.fwd_mode);
+        auto reserve = get_reserve_size(cudnn_handle(handle_cuda()), tensor_desc);
+        // Forward
+        for (int i = 0; i < 9; ++i) {
             checker.set_dtype(i, dtype::Float32());
         }
         checker.set_dtype(0, arg.dtype);
+        checker.set_dtype(7, dtype::Byte());
+        checker.set_dtype(8, arg.dtype);
+        checker.set_bypass(7);
         checker.set_epsilon(1e-3).set_param(arg.param);
         for (bool need_statistic : {false, true})
             checker.exec({
@@ -40,27 +54,26 @@ TEST_F(CUDA, BN_FORWARD) {
                                      : TensorShape({0}),  // mean
                     need_statistic ? arg.param_shape
                                    : TensorShape({0}),  // variance
-                    arg.param_shape,  // batch_mean
-                    arg.param_shape,  // batch_inv_variance
-                    {}                // dst
+                    arg.param_shape,  // batch_mean
+                    arg.param_shape,  // batch_inv_variance
+                    {reserve},        // reserve
+                    arg.src           // dst
             });
-    }
-}
 
-TEST_F(CUDA, BN_BACKWARD) {
-    using namespace batch_normalization;
-    std::vector<TestArg> args = get_args();
-    Checker<BNBackward> checker(handle_cuda());
-    for (auto&& arg : args) {
-        for (int i = 0; i < 8; ++i) {
-            checker.set_dtype(i, dtype::Float32());
+        // Backward
+        for (int i = 0; i < 9; ++i) {
+            checker_bwd.set_dtype(i, dtype::Float32());
         }
-        checker.set_dtype(0, arg.dtype)   // x
-                .set_dtype(1, arg.dtype)  // dy
-                .set_dtype(7, arg.dtype); // dx
-        checker.set_epsilon(1e-3).set_param(arg.param).exec(
+        checker_bwd
+                .set_dtype(0, arg.dtype)      // x
+                .set_dtype(1, arg.dtype)      // dy
+                .set_dtype(5, dtype::Byte())  // reserve
+                .set_dtype(8, arg.dtype)      // dx
+                .set_bypass(5);
+        checker_bwd.set_epsilon(1e-3).set_param(arg.param).exec(
                 {arg.src, arg.src, arg.param_shape, arg.param_shape,
-                 arg.param_shape, arg.param_shape, arg.param_shape, arg.src});
+                 arg.param_shape, {reserve}, arg.param_shape, arg.param_shape,
+                 arg.src});
     }
 }
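
Note on the exec() calls above: the reserve slot is passed as the brace-initialized shape {reserve}, i.e. a one-dimensional shape whose single extent is the value returned by get_reserve_size, with dtype Byte; the ROCm tests below pass {0}, an extent-0 shape, without querying any reserve size. A stand-in sketch (hypothetical Shape type, not MegDNN's TensorShape) of how such brace-initialized shapes read:

// Illustrative sketch only: a minimal Shape type with an initializer-list
// constructor, mirroring how {reserve} and {0} are read in the tests above.
#include <cstddef>
#include <cstdio>
#include <initializer_list>
#include <vector>

struct Shape {
    std::vector<std::size_t> dims;
    Shape(std::initializer_list<std::size_t> d) : dims(d) {}
};

int main() {
    std::size_t reserve = 4096;   // hypothetical value from the backend query
    Shape with_reserve{reserve};  // rank 1, extent 4096: a real reserve buffer
    Shape no_reserve{0};          // rank 1, extent 0: effectively no reserve
    std::printf("%zu %zu\n", with_reserve.dims[0], no_reserve.dims[0]);
    return 0;
}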




dnn/test/rocm/bn.cpp (+7, -4)

@@ -31,6 +31,7 @@ TEST_F(ROCM, BN_FORWARD) {
             checker.set_dtype(i, dtype::Float32());
         }
         checker.set_dtype(0, arg.dtype);
+        checker.set_dtype(8, arg.dtype);
         checker.set_epsilon(1e-3).set_param(arg.param);
         for (bool need_statistic : {false, true})
             checker.exec({
@@ -43,7 +44,8 @@ TEST_F(ROCM, BN_FORWARD) {
                                    : TensorShape({0}),  // variance
                     arg.param_shape,  // batch_mean
                     arg.param_shape,  // batch_inv_variance
-                    {}        // dst
+                    {0},      // reserve
+                    arg.src   // dst
             });
     }
 }
@@ -53,15 +55,16 @@ TEST_F(ROCM, BN_BACKWARD) {
     std::vector<TestArg> args = get_args();
     Checker<BNBackward> checker(handle_rocm());
     for (auto&& arg : args) {
-        for (int i = 0; i < 8; ++i) {
+        for (int i = 0; i < 9; ++i) {
             checker.set_dtype(i, dtype::Float32());
         }
         checker.set_dtype(0, arg.dtype)   // x
                 .set_dtype(1, arg.dtype)  // dy
-                .set_dtype(7, arg.dtype); // dx
+                .set_dtype(8, arg.dtype); // dx
         checker.set_epsilon(1e-3).set_param(arg.param).exec(
                 {arg.src, arg.src, arg.param_shape, arg.param_shape,
-                 arg.param_shape, arg.param_shape, arg.param_shape, arg.src});
+                 arg.param_shape, {0}, arg.param_shape, arg.param_shape,
+                 arg.src});
     }
 }



