You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

conv_pooling.cpp 3.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. #include "test/common/conv_pooling.h"
  2. namespace megdnn {
  3. namespace test {
  4. namespace conv_pooling {
  5. /* ConvPooling(
  6. Method method_=Method::WITH_TEXTURE_OBJ,
  7. ConvMode convMode_=ConvMode::CROSS_CORRELATION,
  8. PoolMode poolMode_=PoolMode::AVERAGE,
  9. NonlineMode nonlineMode_=NonlineMode::IDENTITY,
  10. uint32_t pool_shape_h_=1,
  11. uint32_t pool_shape_w_=1,
  12. uint32_t pool_stride_h_=1,
  13. uint32_t pool_stride_w_=1,
  14. uint32_t pool_pad_h_=0,
  15. uint32_t pool_pad_w_=0,
  16. uint32_t conv_stride_h_=1,
  17. uint32_t conv_stride_w_=1,
  18. uint32_t conv_pad_h_=0,
  19. uint32_t conv_pad_w_=0,
  20. float *bias_=NULL)
  21. */
  22. std::vector<TestArg> get_args() {
  23. std::vector<TestArg> args;
  24. uint32_t pool_shape_h = 3;
  25. uint32_t pool_shape_w = 3;
  26. uint32_t pool_stride_h = pool_shape_h;
  27. uint32_t pool_stride_w = pool_shape_w;
  28. param::ConvPooling cur_param(
  29. param::ConvPooling::Method::WITH_TEXTURE_OBJ,
  30. param::ConvPooling::ConvMode::CONVOLUTION,
  31. param::ConvPooling::PoolMode::MAX, param::ConvPooling::NonlineMode::RELU,
  32. pool_shape_h, pool_shape_w, pool_stride_h, pool_stride_w, 0, 0, 1, 1, 0, 0);
  33. std::vector<param::ConvPooling::ConvMode> conv_mode;
  34. conv_mode.push_back(param::ConvPooling::ConvMode::CONVOLUTION);
  35. conv_mode.push_back(param::ConvPooling::ConvMode::CROSS_CORRELATION);
  36. std::vector<param::ConvPooling::NonlineMode> nonline_mode;
  37. nonline_mode.push_back(param::ConvPooling::NonlineMode::IDENTITY);
  38. nonline_mode.push_back(param::ConvPooling::NonlineMode::SIGMOID);
  39. nonline_mode.push_back(param::ConvPooling::NonlineMode::RELU);
  40. for (size_t i = 19; i < 21; ++i) {
  41. for (size_t i_nl_mode = 0; i_nl_mode < nonline_mode.size(); ++i_nl_mode) {
  42. cur_param.nonlineMode = nonline_mode[i_nl_mode];
  43. for (size_t i_conv_mode = 0; i_conv_mode < conv_mode.size();
  44. ++i_conv_mode) {
  45. for (size_t kernel_size = 1; kernel_size < 7; ++kernel_size) {
  46. for (size_t pool_size = 1; pool_size < 5; ++pool_size) {
  47. if (pool_size >= kernel_size)
  48. continue;
  49. cur_param.convMode = conv_mode[i_conv_mode];
  50. args.emplace_back(
  51. cur_param, TensorShape{20, 4, i, i},
  52. TensorShape{3, 4, 4, 4}, TensorShape{1, 3, 1, 1});
  53. }
  54. }
  55. }
  56. }
  57. }
  58. /*
  59. // large channel
  60. for (size_t i = 20; i < 22; ++i) {
  61. cur_param.convMode = param::ConvPooling::ConvMode::CONVOLUTION;
  62. args.emplace_back(cur_param,
  63. TensorShape{2, 20, i, i+1},
  64. TensorShape{30, 20, 4, 4},
  65. TensorShape{1, 30, 1, 1});
  66. cur_param.convMode = param::ConvPooling::ConvMode::CROSS_CORRELATION;
  67. args.emplace_back(cur_param,
  68. TensorShape{2, 20, i, i+1},
  69. TensorShape{30, 20, 3, 3},
  70. TensorShape{1, 30, 1, 1});
  71. }
  72. // large filter
  73. for (size_t i = 20; i < 22; ++i) {
  74. cur_param.convMode = param::ConvPooling::ConvMode::CONVOLUTION;
  75. args.emplace_back(cur_param,
  76. TensorShape{2, 2, i, i+1},
  77. TensorShape{3, 2, 5, 5},
  78. TensorShape{1, 3, 1, 1});
  79. cur_param.convMode = param::ConvPooling::ConvMode::CROSS_CORRELATION;
  80. cur_param.convMode =
  81. param::ConvPooling::ConvMode::CROSS_CORRELATION; args.emplace_back(cur_param,
  82. TensorShape{2, 2, i, i+1},
  83. TensorShape{3, 2, 5, 5},
  84. TensorShape{1, 3, 1, 1});
  85. }
  86. */
  87. return args;
  88. }
  89. } // namespace conv_pooling
  90. } // namespace test
  91. } // namespace megdnn
  92. // vim: syntax=cpp.doxygen