You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
  1. /**
  2. * \file dnn/test/common/bn.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #pragma once
  12. #include "megdnn/basic_types.h"
  13. #include "megdnn/opr_param_defs.h"
  14. namespace megdnn {
  15. namespace test {
  16. namespace batch_normalization {
  17. struct TestArg {
  18. param::BN param;
  19. TensorShape src, param_shape;
  20. DType dtype;
  21. TestArg(param::BN param, TensorShape src, TensorShape param_shape,
  22. DType dtype)
  23. : param(param), src(src), param_shape(param_shape), dtype(dtype) {}
  24. };
  25. std::vector<TestArg> get_args() {
  26. std::vector<TestArg> args;
  27. // Case 1
  28. // ParamDim: 1 x 1 x H x W
  29. // N = 3, C = 3
  30. for (size_t i = 4; i < 257; i *= 4) {
  31. param::BN param;
  32. param.fwd_mode = param::BN::FwdMode::TRAINING;
  33. param.param_dim = param::BN::ParamDim::DIM_11HW;
  34. param.avg_factor = 1.f;
  35. args.emplace_back(param, TensorShape{2, 3, i, i},
  36. TensorShape{1, 1, i, i}, dtype::Float32());
  37. args.emplace_back(param, TensorShape{2, 3, i, i},
  38. TensorShape{1, 1, i, i}, dtype::Float16());
  39. }
  40. // case 2: 1 x C x 1 x 1
  41. for (size_t i = 4; i < 257; i *= 4) {
  42. param::BN param;
  43. param.fwd_mode = param::BN::FwdMode::TRAINING;
  44. param.param_dim = param::BN::ParamDim::DIM_1C11;
  45. args.emplace_back(param, TensorShape{3, 3, i, i},
  46. TensorShape{1, 3, 1, 1}, dtype::Float32());
  47. args.emplace_back(param, TensorShape{3, 3, i, i},
  48. TensorShape{1, 3, 1, 1}, dtype::Float16());
  49. }
  50. return args;
  51. }
  52. } // namespace batch_normalization
  53. } // namespace test
  54. } // namespace megdnn
  55. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。