You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

mask_conv.h 4.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. #include "megdnn/oprs.h"
  2. #include "test/common/benchmarker.h"
  3. #include "test/common/checker.h"
  4. #include "test/common/rng.h"
  5. #pragma once
  6. namespace {
  7. using namespace megdnn;
  8. using namespace test;
  // Convolution configurations shared by mask_conv_test and mask_conv_benchmark.
  // Every row holds 13 values, consumed positionally as
  //   {N, IC, OC, IH, IW, FH, FW, SH, SW, PH, PW, DH, DW}
  // i.e. batch, input channels, output channels, input height/width,
  // filter height/width, stride (h, w), padding (h, w) and dilation (h, w).
  std::vector<std::vector<int>> get_args() {
      return {
              // small inputs, including a non-square filter
              {2, 1, 1, 5, 5, 3, 3, 1, 1, 0, 0, 1, 1},
              {1, 2, 3, 24, 24, 3, 5, 1, 1, 0, 0, 1, 1},
              // non-square input/filter; sweep stride, padding and dilation
              {20, 3, 4, 24, 21, 5, 3, 1, 1, 0, 0, 1, 1},
              {20, 3, 4, 24, 21, 5, 3, 2, 2, 0, 0, 1, 1},
              {20, 3, 4, 24, 21, 5, 3, 2, 2, 2, 2, 1, 1},
              {20, 3, 4, 24, 21, 5, 3, 2, 2, 1, 2, 1, 1},
              {20, 3, 4, 24, 21, 5, 3, 2, 2, 1, 2, 2, 3},
              {20, 3, 4, 24, 21, 5, 3, 2, 2, 1, 2, 3, 2},
              // wide channel count
              {2, 108, 108, 14, 14, 3, 3, 1, 1, 0, 0, 1, 1},
              {2, 108, 108, 14, 14, 3, 3, 1, 1, 2, 2, 1, 1},
              {2, 108, 108, 14, 14, 3, 3, 2, 2, 2, 2, 1, 1},
              {2, 108, 108, 14, 14, 3, 3, 2, 2, 0, 0, 1, 1},
              // large spatial extent
              {2, 3, 3, 224, 224, 3, 3, 1, 1, 0, 0, 1, 1},
              {2, 3, 3, 224, 224, 3, 3, 2, 2, 0, 0, 1, 1},
      };
  }
  27. void mask_conv_test(Handle* handle) {
  28. auto run = [&](size_t N, size_t IC, size_t OC, size_t IH, size_t IW, size_t FH,
  29. size_t FW, size_t SH, size_t SW, size_t PH, size_t PW, size_t DH,
  30. size_t DW) {
  31. size_t OH = (IH + 2 * PH - ((FH - 1) * DH + 1)) / SH + 1;
  32. size_t OW = (IW + 2 * PW - ((FW - 1) * DW + 1)) / SW + 1;
  33. Checker<MaskConvolution> checker(handle);
  34. using Param = param::Convolution;
  35. Param param(
  36. Param::Mode::CROSS_CORRELATION,
  37. // pad
  38. PH, PW,
  39. // stride
  40. SH, SW,
  41. // dilate
  42. DH, DW, Param::Sparse::DENSE, Param::Format::NCHW);
  43. TensorShape src_shape({N, IC, IH, IW}), filter_shape({OC, IC, FH, FW}),
  44. mask({OH, OW}), dst({});
  45. auto rng = std::make_unique<BernoulliRNG>(0.5);
  46. checker.set_param(param);
  47. checker.set_dtype(2, dtype::Int8()).execs({src_shape, filter_shape, mask, dst});
  48. checker.set_dtype(2, dtype::Int16())
  49. .execs({src_shape, filter_shape, mask, dst});
  50. checker.set_dtype(2, dtype::Int32())
  51. .execs({src_shape, filter_shape, mask, dst});
  52. };
  53. auto test_args = get_args();
  54. for (auto&& arg : test_args) {
  55. run(arg[0], arg[1], arg[2], arg[3], arg[4], arg[5], arg[6], arg[7], arg[8],
  56. arg[9], arg[10], arg[11], arg[12]);
  57. }
  58. }
  #if MEGDNN_WITH_BENCHMARK
  // Benchmarks MaskConvolution against an ordinary Convolution on `handle`.
  //
  // For every configuration returned by get_args():
  //   * MaskConvolution is timed (20 runs, Int32 mask at input index 2) once per
  //     mask density in {0.1, 0.2, 0.3, 0.4, 0.5, 0.99}, the mask being drawn by
  //     a BernoulliRNG with that probability;
  //   * Convolution is then timed (20 runs) on the same src/filter shapes.
  void mask_conv_benchmark(Handle* handle) {
      auto benchmark = [&](size_t N, size_t IC, size_t OC, size_t IH, size_t IW,
                           size_t FH, size_t FW, size_t SH, size_t SW, size_t PH,
                           size_t PW, size_t DH, size_t DW) {
          // Output plane of a padded, strided, dilated window; the mask tensor is
          // shaped (OH, OW) to match it.
          size_t OH = (IH + 2 * PH - ((FH - 1) * DH + 1)) / SH + 1;
          size_t OW = (IW + 2 * PW - ((FW - 1) * DW + 1)) / SW + 1;
          Benchmarker<MaskConvolution> benchmark_fallback(handle);
          Benchmarker<Convolution> benchmark_naive(handle);
          using Param = param::Convolution;
          Param param(
                  Param::Mode::CROSS_CORRELATION,
                  // pad
                  PH, PW,
                  // stride
                  SH, SW,
                  // dilate
                  DH, DW, Param::Sparse::DENSE, Param::Format::NCHW);
          TensorShape src_shape({N, IC, IH, IW}), filter_shape({OC, IC, FH, FW}),
                  mask({OH, OW}), dst({});
          benchmark_fallback.set_param(param).set_dtype(2, dtype::Int32()).set_times(20);
          printf("Execing mask conv: \n");
          for (auto p : {0.1, 0.2, 0.3, 0.4, 0.5, 0.99}) {
              // The generator used to be allocated with a bare `new` and never
              // freed (one leak per density per shape). A stack object owns it
              // for exactly the duration of the run that uses it.
              // NOTE(review): assumes set_rng stores a non-owning pointer used
              // only inside execs(); benchmark_fallback is not run again after
              // this loop, so the pointer left behind is never dereferenced.
              BernoulliRNG mask_rng(p);
              benchmark_fallback.set_rng(2, &mask_rng)
                      .execs({src_shape, filter_shape, mask, dst});
          }
          printf("Execing normal conv: \n");
          benchmark_naive.set_param(param).set_times(20).execs(
                  {src_shape, filter_shape, dst});
      };
      auto test_args = get_args();
      for (auto&& arg : test_args) {
          benchmark(
                  arg[0], arg[1], arg[2], arg[3], arg[4], arg[5], arg[6], arg[7], arg[8],
                  arg[9], arg[10], arg[11], arg[12]);
      }
  }
  #endif
  99. } // namespace