You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

bn.h 2.5 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. /**
  2. * \file dnn/test/common/bn.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #pragma once
  12. #include "megdnn/basic_types.h"
  13. #include "megdnn/opr_param_defs.h"
  14. namespace megdnn {
  15. namespace test {
  16. namespace batch_normalization {
  17. struct TestArg {
  18. param::BN param;
  19. TensorShape src, param_shape;
  20. DType dtype;
  21. TestArg(param::BN param, TensorShape src, TensorShape param_shape, DType dtype)
  22. : param(param), src(src), param_shape(param_shape), dtype(dtype) {}
  23. };
  24. std::vector<TestArg> get_args() {
  25. std::vector<TestArg> args;
  26. // Case 1
  27. // ParamDim: 1 x 1 x H x W
  28. // N = 3, C = 3
  29. for (size_t i = 4; i < 257; i *= 4) {
  30. param::BN param;
  31. param.fwd_mode = param::BN::FwdMode::TRAINING;
  32. param.param_dim = param::BN::ParamDim::DIM_11HW;
  33. param.avg_factor = 1.f;
  34. args.emplace_back(
  35. param, TensorShape{2, 3, i, i}, TensorShape{1, 1, i, i},
  36. dtype::Float32());
  37. args.emplace_back(
  38. param, TensorShape{2, 3, i, i}, TensorShape{1, 1, i, i},
  39. dtype::Float16());
  40. }
  41. // case 2: 1 x C x 1 x 1
  42. for (size_t i = 4; i < 257; i *= 4) {
  43. param::BN param;
  44. param.fwd_mode = param::BN::FwdMode::TRAINING;
  45. param.param_dim = param::BN::ParamDim::DIM_1C11;
  46. args.emplace_back(
  47. param, TensorShape{3, 3, i, i}, TensorShape{1, 3, 1, 1},
  48. dtype::Float32());
  49. args.emplace_back(
  50. param, TensorShape{3, 3, i, i}, TensorShape{1, 3, 1, 1},
  51. dtype::Float16());
  52. }
  53. // case 3: 1 x 1 x 1 x C
  54. for (size_t i = 4; i < 257; i *= 4) {
  55. param::BN param;
  56. param.fwd_mode = param::BN::FwdMode::TRAINING;
  57. param.param_dim = param::BN::ParamDim::DIM_111C;
  58. args.emplace_back(
  59. param, TensorShape{3, i, i, 3}, TensorShape{1, 1, 1, 3},
  60. dtype::Float32());
  61. args.emplace_back(
  62. param, TensorShape{3, i, i, 3}, TensorShape{1, 1, 1, 3},
  63. dtype::Float16());
  64. }
  65. return args;
  66. }
  67. } // namespace batch_normalization
  68. } // namespace test
  69. } // namespace megdnn
  70. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台