|
- #pragma once
- #include "megdnn/basic_types.h"
- #include "megdnn/opr_param_defs.h"
-
- namespace megdnn {
- namespace test {
- namespace batch_normalization {
-
- struct TestArg {
- param::BN param;
- TensorShape src, param_shape;
- DType dtype;
- TestArg(param::BN param, TensorShape src, TensorShape param_shape, DType dtype)
- : param(param), src(src), param_shape(param_shape), dtype(dtype) {}
- };
-
- std::vector<TestArg> get_args() {
- std::vector<TestArg> args;
- // Case 1
- // ParamDim: 1 x 1 x H x W
- // N = 3, C = 3
- for (size_t i = 4; i < 257; i *= 4) {
- param::BN param;
- param.fwd_mode = param::BN::FwdMode::TRAINING;
- param.param_dim = param::BN::ParamDim::DIM_11HW;
- param.avg_factor = 1.f;
- args.emplace_back(
- param, TensorShape{2, 3, i, i}, TensorShape{1, 1, i, i},
- dtype::Float32());
- args.emplace_back(
- param, TensorShape{2, 3, i, i}, TensorShape{1, 1, i, i},
- dtype::Float16());
- }
-
- // case 2: 1 x C x 1 x 1
-
- for (size_t i = 4; i < 257; i *= 4) {
- param::BN param;
- param.fwd_mode = param::BN::FwdMode::TRAINING;
- param.param_dim = param::BN::ParamDim::DIM_1C11;
- args.emplace_back(
- param, TensorShape{3, 3, i, i}, TensorShape{1, 3, 1, 1},
- dtype::Float32());
- args.emplace_back(
- param, TensorShape{3, 3, i, i}, TensorShape{1, 3, 1, 1},
- dtype::Float16());
- }
-
- // case 3: 1 x 1 x 1 x C
-
- for (size_t i = 4; i < 257; i *= 4) {
- param::BN param;
- param.fwd_mode = param::BN::FwdMode::TRAINING;
- param.param_dim = param::BN::ParamDim::DIM_111C;
- args.emplace_back(
- param, TensorShape{3, i, i, 3}, TensorShape{1, 1, 1, 3},
- dtype::Float32());
- args.emplace_back(
- param, TensorShape{3, i, i, 3}, TensorShape{1, 1, 1, 3},
- dtype::Float16());
- }
-
- return args;
- }
-
- } // namespace batch_normalization
- } // namespace test
- } // namespace megdnn
-
- // vim: syntax=cpp.doxygen
|