You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

bn.h 2.1 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. #pragma once
#include <vector>

#include "megdnn/basic_types.h"
#include "megdnn/opr_param_defs.h"
  4. namespace megdnn {
  5. namespace test {
  6. namespace batch_normalization {
  7. struct TestArg {
  8. param::BN param;
  9. TensorShape src, param_shape;
  10. DType dtype;
  11. TestArg(param::BN param, TensorShape src, TensorShape param_shape, DType dtype)
  12. : param(param), src(src), param_shape(param_shape), dtype(dtype) {}
  13. };
  14. std::vector<TestArg> get_args() {
  15. std::vector<TestArg> args;
  16. // Case 1
  17. // ParamDim: 1 x 1 x H x W
  18. // N = 3, C = 3
  19. for (size_t i = 4; i < 257; i *= 4) {
  20. param::BN param;
  21. param.fwd_mode = param::BN::FwdMode::TRAINING;
  22. param.param_dim = param::BN::ParamDim::DIM_11HW;
  23. param.avg_factor = 1.f;
  24. args.emplace_back(
  25. param, TensorShape{2, 3, i, i}, TensorShape{1, 1, i, i},
  26. dtype::Float32());
  27. args.emplace_back(
  28. param, TensorShape{2, 3, i, i}, TensorShape{1, 1, i, i},
  29. dtype::Float16());
  30. }
  31. // case 2: 1 x C x 1 x 1
  32. for (size_t i = 4; i < 257; i *= 4) {
  33. param::BN param;
  34. param.fwd_mode = param::BN::FwdMode::TRAINING;
  35. param.param_dim = param::BN::ParamDim::DIM_1C11;
  36. args.emplace_back(
  37. param, TensorShape{3, 3, i, i}, TensorShape{1, 3, 1, 1},
  38. dtype::Float32());
  39. args.emplace_back(
  40. param, TensorShape{3, 3, i, i}, TensorShape{1, 3, 1, 1},
  41. dtype::Float16());
  42. }
  43. // case 3: 1 x 1 x 1 x C
  44. for (size_t i = 4; i < 257; i *= 4) {
  45. param::BN param;
  46. param.fwd_mode = param::BN::FwdMode::TRAINING;
  47. param.param_dim = param::BN::ParamDim::DIM_111C;
  48. args.emplace_back(
  49. param, TensorShape{3, i, i, 3}, TensorShape{1, 1, 1, 3},
  50. dtype::Float32());
  51. args.emplace_back(
  52. param, TensorShape{3, i, i, 3}, TensorShape{1, 1, 1, 3},
  53. dtype::Float16());
  54. }
  55. return args;
  56. }
  57. } // namespace batch_normalization
  58. } // namespace test
  59. } // namespace megdnn
  60. // vim: syntax=cpp.doxygen