You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

pooling.cpp 3.1 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. #include "test/cpu/fixture.h"
  2. #include "test/common/benchmarker.h"
  3. #include "test/common/checker.h"
  4. #include "test/common/pooling.h"
  5. namespace megdnn {
  6. namespace test {
  7. TEST_F(CPU, POOLING) {
  8. auto args = pooling::get_args();
  9. using Format = param::Pooling::Format;
  10. for (auto dtype : std::vector<DType>{dtype::Int8(), dtype::Float32()})
  11. for (Format format : {Format::NCHW, Format::NHWC})
  12. for (auto&& arg : args) {
  13. auto param = arg.param;
  14. auto src = arg.ishape;
  15. Checker<Pooling> checker(handle());
  16. param.format = format;
  17. if (param.format == Format::NHWC) {
  18. src = cvt_src_or_dst_nchw2nhwc(src);
  19. }
  20. checker.set_param(param).set_dtype(0, dtype).set_dtype(1, dtype).exec(
  21. TensorShapeArray{src, {}});
  22. }
  23. }
  24. TEST_F(CPU, POOLING_INT) {
  25. UniformIntRNG rng(0, 255);
  26. for (int modeflag = 0; modeflag < 2; ++modeflag) {
  27. param::Pooling param;
  28. param.mode =
  29. modeflag ? param::Pooling::Mode::AVERAGE : param::Pooling::Mode::MAX;
  30. param.window_h = param.window_w = 2;
  31. param.stride_h = param.stride_w = 2;
  32. param.pad_h = param.pad_w = 0;
  33. std::vector<size_t> sizes = {10, 12, 13, 15, 20, 63};
  34. for (size_t ih : sizes)
  35. for (size_t iw : sizes) {
  36. Checker<Pooling> checker(handle());
  37. checker.set_rng(0, &rng);
  38. checker.set_rng(1, &rng);
  39. checker.set_rng(2, &rng);
  40. checker.set_dtype(0, dtype::Int8());
  41. checker.set_dtype(1, dtype::Int8());
  42. checker.set_param(param).exec(TensorShapeArray{{2, 3, ih, iw}, {}});
  43. }
  44. }
  45. }
  46. #if MEGDNN_WITH_BENCHMARK
  47. TEST_F(CPU, BENCHMARK_POOLING_INT) {
  48. UniformIntRNG rng(0, 255);
  49. for (int modeflag = 0; modeflag < 2; ++modeflag) {
  50. param::Pooling param;
  51. if (modeflag) {
  52. param.mode = param::Pooling::Mode::MAX;
  53. std::cout << "mode=max" << std::endl;
  54. } else {
  55. param.mode = param::Pooling::Mode::AVERAGE;
  56. std::cout << "mode=avg" << std::endl;
  57. }
  58. param.window_h = param.window_w = 2;
  59. param.stride_h = param.stride_w = 2;
  60. param.pad_h = param.pad_w = 0;
  61. float time_int, time_float;
  62. {
  63. std::cout << "int: ";
  64. Benchmarker<Pooling> benchmarker(handle());
  65. benchmarker.set_dtype(0, dtype::Int8());
  66. benchmarker.set_dtype(1, dtype::Int8());
  67. benchmarker.set_rng(0, &rng);
  68. benchmarker.set_rng(1, &rng);
  69. time_int = benchmarker.set_param(param).exec({{2, 3, 640, 480}, {}});
  70. }
  71. {
  72. std::cout << "float: ";
  73. Benchmarker<Pooling> benchmarker(handle());
  74. time_float = benchmarker.set_param(param).exec({{2, 3, 640, 480}, {}});
  75. }
  76. printf("time: int=%.3fms float=%.3fms\n", time_int, time_float);
  77. }
  78. }
  79. #endif
  80. } // namespace test
  81. } // namespace megdnn
  82. // vim: syntax=cpp.doxygen