
conv_pooling.cpp 4.2 kB

/**
 * \file dnn/test/common/conv_pooling.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "test/common/conv_pooling.h"

namespace megdnn {
namespace test {
namespace conv_pooling {

/* ConvPooling(
 *     Method method_=Method::WITH_TEXTURE_OBJ,
 *     ConvMode convMode_=ConvMode::CROSS_CORRELATION,
 *     PoolMode poolMode_=PoolMode::AVERAGE,
 *     NonlineMode nonlineMode_=NonlineMode::IDENTITY,
 *     uint32_t pool_shape_h_=1,
 *     uint32_t pool_shape_w_=1,
 *     uint32_t pool_stride_h_=1,
 *     uint32_t pool_stride_w_=1,
 *     uint32_t pool_pad_h_=0,
 *     uint32_t pool_pad_w_=0,
 *     uint32_t conv_stride_h_=1,
 *     uint32_t conv_stride_w_=1,
 *     uint32_t conv_pad_h_=0,
 *     uint32_t conv_pad_w_=0,
 *     float *bias_=NULL)
 */
std::vector<TestArg> get_args() {
    std::vector<TestArg> args;
    uint32_t pool_shape_h = 3;
    uint32_t pool_shape_w = 3;
    uint32_t pool_stride_h = pool_shape_h;
    uint32_t pool_stride_w = pool_shape_w;
    // Base parameter set: max-pooling with ReLU; the conv and nonlinearity
    // modes are overwritten inside the loops below.
    param::ConvPooling cur_param(
            param::ConvPooling::Method::WITH_TEXTURE_OBJ,
            param::ConvPooling::ConvMode::CONVOLUTION,
            param::ConvPooling::PoolMode::MAX,
            param::ConvPooling::NonlineMode::RELU,
            pool_shape_h, pool_shape_w, pool_stride_h, pool_stride_w,
            0, 0, 1, 1, 0, 0);
    std::vector<param::ConvPooling::ConvMode> conv_mode;
    conv_mode.push_back(param::ConvPooling::ConvMode::CONVOLUTION);
    conv_mode.push_back(param::ConvPooling::ConvMode::CROSS_CORRELATION);
    std::vector<param::ConvPooling::NonlineMode> nonline_mode;
    nonline_mode.push_back(param::ConvPooling::NonlineMode::IDENTITY);
    nonline_mode.push_back(param::ConvPooling::NonlineMode::SIGMOID);
    nonline_mode.push_back(param::ConvPooling::NonlineMode::RELU);
    // Enumerate input spatial sizes, nonlinearity and convolution modes;
    // each combination yields one (param, src, filter, bias) test case.
    for (size_t i = 19; i < 21; ++i) {
        for (size_t i_nl_mode = 0; i_nl_mode < nonline_mode.size(); ++i_nl_mode) {
            cur_param.nonlineMode = nonline_mode[i_nl_mode];
            for (size_t i_conv_mode = 0; i_conv_mode < conv_mode.size(); ++i_conv_mode) {
                for (size_t kernel_size = 1; kernel_size < 7; ++kernel_size) {
                    for (size_t pool_size = 1; pool_size < 5; ++pool_size) {
                        if (pool_size >= kernel_size)
                            continue;
                        cur_param.convMode = conv_mode[i_conv_mode];
                        args.emplace_back(cur_param,
                                TensorShape{20, 4, i, i},
                                TensorShape{3, 4, 4, 4},
                                TensorShape{1, 3, 1, 1});
                    }
                }
            }
        }
    }
    /*
    // large channel
    for (size_t i = 20; i < 22; ++i) {
        cur_param.convMode = param::ConvPooling::ConvMode::CONVOLUTION;
        args.emplace_back(cur_param,
                TensorShape{2, 20, i, i+1},
                TensorShape{30, 20, 4, 4},
                TensorShape{1, 30, 1, 1});
        cur_param.convMode = param::ConvPooling::ConvMode::CROSS_CORRELATION;
        args.emplace_back(cur_param,
                TensorShape{2, 20, i, i+1},
                TensorShape{30, 20, 3, 3},
                TensorShape{1, 30, 1, 1});
    }
    // large filter
    for (size_t i = 20; i < 22; ++i) {
        cur_param.convMode = param::ConvPooling::ConvMode::CONVOLUTION;
        args.emplace_back(cur_param,
                TensorShape{2, 2, i, i+1},
                TensorShape{3, 2, 5, 5},
                TensorShape{1, 3, 1, 1});
        cur_param.convMode = param::ConvPooling::ConvMode::CROSS_CORRELATION;
        args.emplace_back(cur_param,
                TensorShape{2, 2, i, i+1},
                TensorShape{3, 2, 5, 5},
                TensorShape{1, 3, 1, 1});
    }
    */
    return args;
}

} // namespace conv_pooling
} // namespace test
} // namespace megdnn
// vim: syntax=cpp.doxygen
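
For orientation, the sketch below shows one way the argument list produced by get_args() could be consumed in a forward test. It is a minimal sketch only: the Checker helper, the ConvPoolingForward operator name, the handle parameter, and the TestArg members param/src/filter/bias are assumptions modelled on other MegDNN test files and are not defined in the file above.

// Hypothetical usage sketch: not part of conv_pooling.cpp.
// The operator name, Checker interface and TestArg field names below are
// assumptions; adjust them to the actual declarations in the test headers.
#include "test/common/conv_pooling.h"
#include "test/common/checker.h"

namespace megdnn {
namespace test {

void run_conv_pooling_forward(Handle* handle) {
    Checker<ConvPoolingForward> checker(handle);  // assumed operator name
    for (auto&& arg : conv_pooling::get_args()) {
        // TestArg is assumed to expose the param and the three shapes
        // (src, filter, bias) passed to emplace_back() in get_args();
        // the trailing empty shape lets the checker deduce the output.
        checker.set_param(arg.param)
                .execs({arg.src, arg.filter, arg.bias, {}});
    }
}

} // namespace test
} // namespace megdnn

In the real test suite such a loop would typically sit inside a backend-specific test body that supplies the concrete handle.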
